{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "primary-organic",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "stopped-semester",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mask_finished_scores(score: torch.Tensor,\n",
    "                         flag: torch.Tensor) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    If a sequence is finished, we only allow one alive branch. This function\n",
    "    aims to give one branch a zero score and the rest -inf score.\n",
    "    Args:\n",
    "        score (torch.Tensor): A real value array with shape\n",
    "            (batch_size * beam_size, beam_size).\n",
    "        flag (torch.Tensor): A bool array with shape\n",
    "            (batch_size * beam_size, 1).\n",
    "    Returns:\n",
    "        torch.Tensor: (batch_size * beam_size, beam_size).\n",
    "    \"\"\"\n",
    "    beam_size = score.size(-1)\n",
    "    zero_mask = torch.zeros_like(flag, dtype=torch.bool)\n",
    "    if beam_size > 1:\n",
    "        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),\n",
    "                               dim=1)\n",
    "        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),\n",
    "                             dim=1)\n",
    "    else:\n",
    "        unfinished = zero_mask\n",
    "        finished = flag\n",
    "    print(unfinished)\n",
    "    print(finished)\n",
    "    score.masked_fill_(unfinished, -float('inf'))\n",
    "    score.masked_fill_(finished, 0)\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "agreed-portuguese",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ True],\n",
      "        [False]])\n",
      "tensor([[-0.8841,  0.7381, -0.9986],\n",
      "        [ 0.2675, -0.7971,  0.3798]])\n",
      "tensor([[ True,  True],\n",
      "        [False, False]])\n"
     ]
    }
   ],
   "source": [
    "score = torch.randn((2, 3))\n",
    "flag = torch.ones((2, 1), dtype=torch.bool)\n",
    "flag[1] = False\n",
    "print(flag)\n",
    "print(score)\n",
    "print(flag.repeat([1, 2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "clean-aspect",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[False,  True,  True],\n",
      "        [False, False, False]])\n",
      "tensor([[ True, False, False],\n",
      "        [False, False, False]])\n",
      "tensor([[ 0.0000,    -inf,    -inf],\n",
      "        [ 0.2675, -0.7971,  0.3798]])\n",
      "tensor([[ 0.0000,    -inf,    -inf],\n",
      "        [ 0.2675, -0.7971,  0.3798]])\n"
     ]
    }
   ],
   "source": [
    "r  = mask_finished_scores(score, flag)\n",
    "print(r)\n",
    "print(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "thrown-airline",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[2, 1], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[True ],\n",
      "        [False]])\n",
      "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[True , True ],\n",
      "        [False, False]])\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "\n",
    "score = paddle.randn((2, 3))\n",
    "flag = paddle.ones((2, 1), dtype='bool')\n",
    "flag[1] = False\n",
    "print(flag)\n",
    "print(score)\n",
    "print(flag.tile([1, 2]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "internal-patent",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[False, True , True ],\n",
      "        [False, False, False]])\n",
      "Tensor(shape=[2, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[True , False, False],\n",
      "        [False, False, False]])\n",
      "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511,  1.87704289,  0.01988174],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "x Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "2 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 2.05994511, -inf.      , -inf.      ],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "3 Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 0.        , -inf.      , -inf.      ],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n",
      "Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[ 0.        , -inf.      , -inf.      ],\n",
      "        [-0.40165186,  0.77547729, -0.64469045]])\n"
     ]
    }
   ],
   "source": [
    "paddle.bool = 'bool'\n",
    "\n",
    "def masked_fill(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
    "    print(xs)\n",
    "    trues = paddle.ones_like(xs) * value\n",
    "    assert xs.shape == mask.shape\n",
    "    xs = paddle.where(mask, trues, xs)\n",
    "    return xs\n",
    "\n",
    "def masked_fill_(xs:paddle.Tensor, mask:paddle.Tensor, value:float):\n",
    "    print('x', xs)\n",
    "    trues = paddle.ones_like(xs) * value\n",
    "    assert xs.shape == mask.shape\n",
    "    ret = paddle.where(mask, trues, xs)\n",
    "    print('2', xs)\n",
    "    paddle.assign(ret, output=xs)\n",
    "    print('3', xs)\n",
    "\n",
    "paddle.Tensor.masked_fill = masked_fill\n",
    "paddle.Tensor.masked_fill_ = masked_fill_\n",
    "\n",
    "def mask_finished_scores_pd(score: paddle.Tensor,\n",
    "                         flag: paddle.Tensor) -> paddle.Tensor:\n",
    "    \"\"\"\n",
    "    If a sequence is finished, we only allow one alive branch. This function\n",
    "    aims to give one branch a zero score and the rest -inf score.\n",
    "    Args:\n",
    "        score (torch.Tensor): A real value array with shape\n",
    "            (batch_size * beam_size, beam_size).\n",
    "        flag (torch.Tensor): A bool array with shape\n",
    "            (batch_size * beam_size, 1).\n",
    "    Returns:\n",
    "        torch.Tensor: (batch_size * beam_size, beam_size).\n",
    "    \"\"\"\n",
    "    beam_size = score.shape[-1]\n",
    "    zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)\n",
    "    if beam_size > 1:\n",
    "        unfinished = paddle.concat((zero_mask, flag.tile([1, beam_size - 1])),\n",
    "                               axis=1)\n",
    "        finished = paddle.concat((flag, zero_mask.tile([1, beam_size - 1])),\n",
    "                             axis=1)\n",
    "    else:\n",
    "        unfinished = zero_mask\n",
    "        finished = flag\n",
    "    print(unfinished)\n",
    "    print(finished)\n",
    "    \n",
    "    #score.masked_fill_(unfinished, -float('inf'))\n",
    "    #score.masked_fill_(finished, 0)\n",
    "#     infs = paddle.ones_like(score) * -float('inf')\n",
    "#     score = paddle.where(unfinished, infs, score)\n",
    "#     score = paddle.where(finished, paddle.zeros_like(score), score)\n",
    "\n",
    "#     score = score.masked_fill(unfinished, -float('inf'))\n",
    "#     score = score.masked_fill(finished, 0)\n",
    "    score.masked_fill_(unfinished, -float('inf'))\n",
    "    score.masked_fill_(finished, 0)\n",
    "    return score\n",
    "\n",
    "r  = mask_finished_scores_pd(score, flag)\n",
    "print(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "vocal-prime",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method PyCapsule.value of Tensor(shape=[2, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,\n",
       "       [[ 0.        , -inf.      , -inf.      ],\n",
       "        [-0.40165186,  0.77547729, -0.64469045]])>"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "score.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "bacterial-adolescent",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Union, Any"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "absent-fiber",
   "metadata": {},
   "outputs": [],
   "source": [
    "def repeat(xs : paddle.Tensor, *size: Any):\n",
    "    print(size)\n",
    "    return paddle.tile(xs, size)\n",
    "paddle.Tensor.repeat = repeat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "material-harbor",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 2)\n",
      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[True , True ],\n",
      "        [False, False]])\n"
     ]
    }
   ],
   "source": [
    "flag = paddle.ones((2, 1), dtype='bool')\n",
    "flag[1] = False\n",
    "print(flag.repeat(1, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "acute-brighton",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [1]), 2)\n",
      "Tensor(shape=[2, 2], dtype=bool, place=CUDAPlace(0), stop_gradient=True,\n",
      "       [[True , True ],\n",
      "        [False, False]])\n"
     ]
    }
   ],
   "source": [
    "flag = paddle.ones((2, 1), dtype='bool')\n",
    "flag[1] = False\n",
    "print(flag.repeat(paddle.to_tensor(1), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "european-rugby",
   "metadata": {},
   "outputs": [],
   "source": [
    "def size(xs, *args: int):\n",
    "    nargs = len(args)\n",
    "    s = paddle.shape(xs)\n",
    "    assert(nargs <= 1)\n",
    "    if nargs == 1:\n",
    "        return s[args[0]]\n",
    "    else:\n",
    "        return s\n",
    "paddle.Tensor.size = size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "moral-special",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[2], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
       "       [2, 1])"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flag.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "ahead-coach",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
       "       [1])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flag.size(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "incomplete-fitness",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,\n",
       "       [2])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flag.size(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "upset-connectivity",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}