from typing import Optional
from typing import Union

import paddle
import paddle.nn.functional as F
from paddle.nn.layer.conv import _ConvNd

# Names exported by a star-import of this module.
__all__ = ['Conv2DValid']


class Conv2DValid(_ConvNd):
    """
    Conv2d operator for VALID mode padding.

    The padding applied by the forward pass is computed on every call from
    ``valid_trigx`` / ``valid_trigy`` rather than taken from the ``padding``
    constructor argument:

    * trigger ``False`` -> that spatial axis is convolved with zero padding;
    * trigger ``True``  -> that spatial axis is padded on both sides by
      ``(size * (stride - 1) - 1 + kernel_size) // 2``, where ``size`` is
      the input extent along the axis. For ``stride == 1`` this reduces to
      ``(kernel_size - 1) // 2``. Dilation is not part of the formula.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution.
        padding: Forwarded to the base class only; the forward pass defined
            here does not read it.
        dilation: Dilation of the kernel.
        groups: Number of groups of the convolution.
        padding_mode: Forwarded to the base class only; the forward pass
            defined here does not read it.
        weight_attr: Forwarded to the base class.
        bias_attr: Forwarded to the base class.
        data_format: Layout of the input, ``"NCHW"`` or ``"NHWC"``. It
            selects which input axes are treated as spatial and is passed
            on to ``F.conv2d``.
        valid_trigx: Enable computed padding on the first spatial axis.
        valid_trigy: Enable computed padding on the second spatial axis.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int=1,
                 padding: Union[str, int]=0,
                 dilation: int=1,
                 groups: int=1,
                 padding_mode: str='zeros',
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW",
                 valid_trigx: bool=False,
                 valid_trigy: bool=False) -> None:
        # The two positional values after ``kernel_size`` are the base
        # class's 4th and 5th parameters; presumably ``transposed=False``
        # and ``dims=2`` -- confirm against the installed Paddle version.
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            False,
            2,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)
        self.valid_trigx = valid_trigx
        self.valid_trigy = valid_trigy

    def _trig_padding(self, size: int, axis: int) -> int:
        """Return the computed padding for one spatial axis.

        Args:
            size: Extent of the input along that spatial axis.
            axis: Index into ``self._stride`` / ``self._kernel_size``
                (``-2`` for the first spatial axis, ``-1`` for the second).
        """
        return (size * (self._stride[axis] - 1) - 1 +
                self._kernel_size[axis]) // 2

    def _conv_forward(self,
                      input: paddle.Tensor,
                      weight: paddle.Tensor,
                      bias: Optional[paddle.Tensor]):
        """Convolve ``input`` using the trigger-dependent padding.

        Args:
            input: Input tensor laid out according to ``data_format``.
            weight: Convolution kernel.
            bias: Optional bias.

        Returns:
            The result of ``F.conv2d``.
        """
        # Locate the two spatial axes of the input. Channel-last input keeps
        # the channels on the final axis, so the spatial axes shift by one.
        # Any other layout keeps the original (-2, -1) indexing.
        if self._data_format == "NHWC":
            axis_x, axis_y = -3, -2
        else:
            axis_x, axis_y = -2, -1

        validx, validy = 0, 0
        if self.valid_trigx:
            validx = self._trig_padding(input.shape[axis_x], -2)
        if self.valid_trigy:
            validy = self._trig_padding(input.shape[axis_y], -1)

        # data_format must be forwarded explicitly; otherwise F.conv2d
        # falls back to its default layout and ignores the one this layer
        # was constructed with.
        return F.conv2d(
            input,
            weight,
            bias,
            stride=self._stride,
            padding=(validx, validy),
            dilation=self._dilation,
            groups=self._groups,
            data_format=self._data_format)

    def forward(self, input: paddle.Tensor) -> paddle.Tensor:
        """Apply the convolution to ``input`` with the layer's parameters."""
        return self._conv_forward(input, self.weight, self.bias)