Merge pull request #1725 from qingen/database-search
[vec] add GRL to domain adaptationpull/1742/head
commit
0186f522af
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
from paddle.autograd import PyLayer
|
||||||
|
|
||||||
|
|
||||||
|
class GradientReversalFunction(PyLayer):
|
||||||
|
"""Gradient Reversal Layer from:
|
||||||
|
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
|
||||||
|
|
||||||
|
Forward pass is the identity function. In the backward pass,
|
||||||
|
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, lambda_=1):
|
||||||
|
"""Forward in networks
|
||||||
|
"""
|
||||||
|
ctx.save_for_backward(lambda_)
|
||||||
|
return x.clone()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grads):
|
||||||
|
"""Backward in networks
|
||||||
|
"""
|
||||||
|
lambda_, = ctx.saved_tensor()
|
||||||
|
dx = -lambda_ * grads
|
||||||
|
return dx
|
||||||
|
|
||||||
|
|
||||||
|
class GradientReversalLayer(nn.Layer):
|
||||||
|
"""Gradient Reversal Layer from:
|
||||||
|
Unsupervised Domain Adaptation by Backpropagation (Ganin & Lempitsky, 2015)
|
||||||
|
|
||||||
|
Forward pass is the identity function. In the backward pass,
|
||||||
|
the upstream gradients are multiplied by -lambda (i.e. gradient is reversed)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lambda_=1):
|
||||||
|
super(GradientReversalLayer, self).__init__()
|
||||||
|
self.lambda_ = lambda_
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward in networks
|
||||||
|
"""
|
||||||
|
return GradientReversalFunction.apply(x, self.lambda_)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
|
||||||
|
data = paddle.randn([2, 3], dtype="float64")
|
||||||
|
data.stop_gradient = False
|
||||||
|
grl = GradientReversalLayer(1)
|
||||||
|
out = grl(data)
|
||||||
|
out.mean().backward()
|
||||||
|
print(data.grad)
|
||||||
|
|
||||||
|
data = paddle.randn([2, 3], dtype="float64")
|
||||||
|
data.stop_gradient = False
|
||||||
|
grl = GradientReversalLayer(-1)
|
||||||
|
out = grl(data)
|
||||||
|
out.mean().backward()
|
||||||
|
print(data.grad)
|
Loading…
Reference in new issue