# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle.autograd import PyLayer


class GradientReversalFunction(PyLayer):
    """Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    The forward pass is the identity function. In the backward pass, the
    upstream gradients are multiplied by -lambda_ (i.e. the gradient is
    reversed).
    """

    @staticmethod
    def forward(ctx, x, lambda_=1):
        """Identity forward pass; stashes lambda_ for the backward pass."""
        ctx.save_for_backward(lambda_)
        return x.clone()

    @staticmethod
    def backward(ctx, grads):
        """Scale the upstream gradients by -lambda_."""
        lambda_, = ctx.saved_tensor()
        dx = -lambda_ * grads
        return dx


class GradientReversalLayer(nn.Layer):
    """Gradient Reversal Layer from:
    Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)

    The forward pass is the identity function. In the backward pass, the
    upstream gradients are multiplied by -lambda_ (i.e. the gradient is
    reversed).
    """

    def __init__(self, lambda_=1):
        super(GradientReversalLayer, self).__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        """Apply the gradient-reversal function with this layer's lambda_."""
        return GradientReversalFunction.apply(x, self.lambda_)


if __name__ == "__main__":
    paddle.set_device("cpu")

    # With lambda_ = 1 the output equals the input, but the gradient of the
    # mean is negated: every entry of data.grad is -1/6 here.
    data = paddle.randn([2, 3], dtype="float64")
    data.stop_gradient = False
    grl = GradientReversalLayer(1)
    out = grl(data)
    out.mean().backward()
    print(data.grad)

    # With lambda_ = -1 the reversal cancels itself and the gradient matches
    # the ordinary mean gradient: every entry of data.grad is +1/6.
    data = paddle.randn([2, 3], dtype="float64")
    data.stop_gradient = False
    grl = GradientReversalLayer(-1)
    out = grl(data)
    out.mean().backward()
    print(data.grad)
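

# ---------------------------------------------------------------------------
# A minimal sketch (not part of the original module) of how the layer is
# typically wired into a DANN-style model from the cited paper: a shared
# feature extractor feeds a label classifier directly, and a domain
# discriminator through the gradient reversal layer. The class name
# `ToyDANN`, the input width of 8, and all layer sizes are illustrative
# assumptions, not an API defined by this repository.
# ---------------------------------------------------------------------------
class ToyDANN(nn.Layer):
    """Domain-adversarial toy model: the reversed gradient pushes the shared
    extractor toward features the domain discriminator cannot tell apart."""

    def __init__(self, feat_dim=16, n_classes=4, lambda_=1.0):
        super(ToyDANN, self).__init__()
        self.extractor = nn.Linear(8, feat_dim)
        self.classifier = nn.Linear(feat_dim, n_classes)
        self.grl = GradientReversalLayer(lambda_)
        self.discriminator = nn.Linear(feat_dim, 2)  # source vs. target

    def forward(self, x):
        feat = self.extractor(x)
        class_logits = self.classifier(feat)
        # Identity on the way forward; -lambda_ * grad flows back into the
        # extractor, so minimizing the domain loss maximizes domain confusion
        # in the shared features.
        domain_logits = self.discriminator(self.grl(feat))
        return class_logits, domain_logits


if __name__ == "__main__":
    # Smoke test for the sketch above.
    model = ToyDANN()
    class_logits, domain_logits = model(paddle.randn([4, 8]))
    print(class_logits.shape, domain_logits.shape)  # [4, 4] [4, 2]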