@@ -30,24 +30,13 @@ logger = Log(__name__).getlog()
logger.warn = logger.warning

########### hack paddle #############
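# NOTE: the aliases below expose torch-style dtype attributes (paddle.float,
# paddle.long, ...) as plain dtype-name strings, presumably so code written
# against torch dtype names keeps working; paddle APIs accept these strings.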
paddle.bool = 'bool'
paddle.float16 = 'float16'
paddle.half = 'float16'
paddle.float32 = 'float32'
paddle.float = 'float32'
paddle.float64 = 'float64'
paddle.double = 'float64'
paddle.int8 = 'int8'
paddle.int16 = 'int16'
paddle.short = 'int16'
paddle.int32 = 'int32'
paddle.int = 'int32'
paddle.int64 = 'int64'
paddle.long = 'int64'
paddle.uint8 = 'uint8'
paddle.uint16 = 'uint16'
paddle.complex64 = 'complex64'
paddle.complex128 = 'complex128'
paddle.cdouble = 'complex128'
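# Illustration (hypothetical call, assuming the aliases above are active):
#   x = paddle.to_tensor([1, 2, 3], dtype=paddle.long)  # same as dtype='int64'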
@@ -403,45 +392,7 @@ if not hasattr(paddle.nn.functional, 'glu'):
# return x * 0.5 * (1.0 + paddle.erf(x / math.sqrt(2.0)))

# hack loss
def ctc_loss(logits,
             labels,
             input_lengths,
             label_lengths,
             blank=0,
             reduction='mean',
             norm_by_times=True):
    #logger.info("my ctc loss with norm by times")
    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
                                           input_lengths, label_lengths)

    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
    assert reduction in ['mean', 'sum', 'none']
    if reduction == 'mean':
        loss_out = paddle.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    return loss_out


logger.warn(
    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F.ctc_loss = ctc_loss
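# Usage sketch (hypothetical tensors; see the warpctc link above for the exact
# shapes/dtypes the operator expects):
#   loss = F.ctc_loss(logits, labels, input_lengths, label_lengths,
#                     blank=0, reduction='mean')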


########### hack paddle.nn #############
if not hasattr(paddle.nn, 'Module'):
    logger.warn("register user Module to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'Module', paddle.nn.Layer)

# maybe cause assert isinstance(sublayer, core.Layer)
if not hasattr(paddle.nn, 'ModuleList'):
    logger.warn(
        "register user ModuleList to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'ModuleList', paddle.nn.LayerList)
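# Sketch (assuming the fallbacks above were registered): torch-style container
# names now resolve to the native paddle layers, e.g.
#   blocks = paddle.nn.ModuleList([paddle.nn.Linear(8, 8) for _ in range(2)])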


class GLU(nn.Layer):
    """Gated Linear Units (GLU) Layer"""