remove fixed hack api

pull/756/head
Hui Zhang 3 years ago
parent 86e42f3d21
commit 4b5410eecd

@@ -30,24 +30,13 @@ logger = Log(__name__).getlog()
 logger.warn = logger.warning
 ########### hcak paddle #############
-paddle.bool = 'bool'
-paddle.float16 = 'float16'
 paddle.half = 'float16'
-paddle.float32 = 'float32'
 paddle.float = 'float32'
-paddle.float64 = 'float64'
 paddle.double = 'float64'
-paddle.int8 = 'int8'
-paddle.int16 = 'int16'
 paddle.short = 'int16'
-paddle.int32 = 'int32'
 paddle.int = 'int32'
-paddle.int64 = 'int64'
 paddle.long = 'int64'
-paddle.uint8 = 'uint8'
 paddle.uint16 = 'uint16'
-paddle.complex64 = 'complex64'
-paddle.complex128 = 'complex128'
 paddle.cdouble = 'complex128'
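
The dtype names removed above (paddle.bool, paddle.float16, paddle.float32, paddle.int64, and so on) are presumably the ones the targeted Paddle release now defines itself, so they no longer need to be patched in as strings; only the alias-style names Paddle still lacks (paddle.half, paddle.float, paddle.double, ...) stay. A minimal sketch, assuming a Paddle 2.x install where these attributes exist as real dtype objects:

import paddle

# Assumption: Paddle 2.x ships these as dtype objects, making the removed
# string patches (paddle.float32 = 'float32', ...) redundant.
assert hasattr(paddle, 'float32') and hasattr(paddle, 'int64')

x = paddle.ones([2, 3], dtype=paddle.float32)  # dtype object, not a string stand-in
y = x.astype(paddle.int64)
print(x.dtype, y.dtype)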
@@ -403,45 +392,7 @@ if not hasattr(paddle.nn.functional, 'glu'):
 #     return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))
-
-# hack loss
-def ctc_loss(logits,
-             labels,
-             input_lengths,
-             label_lengths,
-             blank=0,
-             reduction='mean',
-             norm_by_times=True):
-    #logger.info("my ctc loss with norm by times")
-    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
-    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    assert reduction in ['mean', 'sum', 'none']
-    if reduction == 'mean':
-        loss_out = paddle.mean(loss_out / label_lengths)
-    elif reduction == 'sum':
-        loss_out = paddle.sum(loss_out)
-    return loss_out
-
-logger.warn(
-    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
-)
-F.ctc_loss = ctc_loss
-
 ########### hcak paddle.nn #############
-if not hasattr(paddle.nn, 'Module'):
-    logger.warn("register user Module to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'Module', paddle.nn.Layer)
-
-# maybe cause assert isinstance(sublayer, core.Layer)
-if not hasattr(paddle.nn, 'ModuleList'):
-    logger.warn(
-        "register user ModuleList to paddle.nn, remove this when fixed!")
-    setattr(paddle.nn, 'ModuleList', paddle.nn.LayerList)
-
 class GLU(nn.Layer):
     """Gated Linear Units (GLU) Layer"""

@@ -48,7 +48,8 @@ class CTCLoss(nn.Layer):
         logits = logits.transpose([1, 0, 2])
         # (TODO:Hui Zhang) ctc loss does not support int64 labels
         ys_pad = ys_pad.astype(paddle.int32)
-        loss = self.loss(logits, ys_pad, hlens, ys_lens)
+        loss = self.loss(
+            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
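
With the override gone, the wrapper now passes norm_by_times explicitly instead of relying on the removed hack's hard-coded default of True. Below is a hypothetical re-creation of the wrapper for illustration only; the real class lives in this repository, its constructor arguments are assumptions, and it assumes paddle.nn.CTCLoss forwards norm_by_times to the underlying op:

import paddle
import paddle.nn as nn

class CTCLossSketch(nn.Layer):
    """Illustrative stand-in for the repository's CTCLoss wrapper (not the real class)."""

    def __init__(self, blank=0, reduction='sum', batch_average=True):
        super().__init__()
        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average

    def forward(self, logits, ys_pad, hlens, ys_lens):
        B = logits.shape[0]
        # CTC expects time-major input: (B, T, C) -> (T, B, C)
        logits = logits.transpose([1, 0, 2])
        # int64 labels are not supported, per the TODO in the hunk above
        ys_pad = ys_pad.astype(paddle.int32)
        # Assumption: this Paddle version accepts norm_by_times here.
        loss = self.loss(
            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
        if self.batch_average:
            loss = loss / B  # batch-size average
        return loss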
