not change ctc grad manual

pull/927/head
Hui Zhang 3 years ago
parent 67555cb906
commit 30499a7654

@@ -355,37 +355,7 @@ if not hasattr(paddle.Tensor, 'tolist'):
     setattr(paddle.Tensor, 'tolist', tolist)

 ########### hcak paddle.nn.functional #############
-# hack loss
-def ctc_loss(logits,
-             labels,
-             input_lengths,
-             label_lengths,
-             blank=0,
-             reduction='mean',
-             norm_by_times=True):
-    #logger.info("my ctc loss with norm by times")
-    ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
-    loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times,
-                                           input_lengths, label_lengths)
-    loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
-    assert reduction in ['mean', 'sum', 'none']
-    if reduction == 'mean':
-        loss_out = paddle.mean(loss_out / label_lengths)
-    elif reduction == 'sum':
-        loss_out = paddle.sum(loss_out)
-    return loss_out
-
-logger.debug(
-    "override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
-)
-F.ctc_loss = ctc_loss
-
-########### hcak paddle.nn #############
+########### hack paddle.nn #############
 from paddle.nn import Layer
 from typing import Optional
 from typing import Mapping
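
For context, the override deleted above forced norm_by_times through the legacy paddle.fluid.layers.warpctc kernel and re-bound it as F.ctc_loss; with the monkey-patch gone, callers fall back to the stock paddle.nn.functional.ctc_loss. A minimal usage sketch of the stock API (shapes and the log_softmax step follow the Paddle 2.x docs example; the tensor sizes here are illustrative only):

    import paddle
    import paddle.nn.functional as F

    # Illustrative sizes: T frames, batch B, D classes (blank id 0), label length L.
    T, B, D, L = 20, 4, 10, 5
    logits = paddle.randn([T, B, D])                            # (T, B, D) activations
    labels = paddle.randint(1, D, shape=[B, L], dtype='int32')  # padded label ids
    input_lengths = paddle.full([B], T, dtype='int64')
    label_lengths = paddle.full([B], L, dtype='int64')

    loss = F.ctc_loss(
        F.log_softmax(logits),   # log-probabilities, as in the Paddle docs example
        labels,
        input_lengths,
        label_lengths,
        blank=0,
        reduction='mean')        # 'mean': divide each loss by its label length, then average
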
@@ -532,3 +502,5 @@ if not hasattr(paddle.nn, 'LayerDict'):
     logger.debug(
         "register user LayerDict to paddle.nn, remove this when fixed!")
     setattr(paddle.nn, 'LayerDict', LayerDict)

@@ -13,6 +13,7 @@
 # limitations under the License.
 import paddle
 from paddle import nn
+from typing import Union
 from paddle.nn import functional as F
 from typeguard import check_argument_types
@@ -40,7 +41,7 @@ class CTCDecoderBase(nn.Layer):
                  dropout_rate: float=0.0,
                  reduction: bool=True,
                  batch_average: bool=True,
-                 grad_norm_type: str="instance"):
+                 grad_norm_type: Union[str, None]=None):
         """CTC decoder

         Args:
@@ -49,7 +50,7 @@ class CTCDecoderBase(nn.Layer):
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
             reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
             batch_average (bool): do batch dim wise average.
-            grad_norm_type (str): one of 'instance', 'batch', 'frame', None.
+            grad_norm_type (str): Default, None. one of 'instance', 'batch', 'frame', None.
         """
         assert check_argument_types()
         super().__init__()
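
With None now the default, no gradient normalization is applied unless a config asks for a mode explicitly. A rough, self-contained sketch of how a grad_norm_type value typically fans out into the three flags CTCLoss keeps in this diff (the helper name and the branch-to-flag mapping beyond 'frame' are inferred, not part of the patch):

    def grad_norm_flags(grad_norm_type):
        # Illustrative mapping only; flag names mirror the diff above.
        norm_by_times = norm_by_batchsize = norm_by_total_logits_len = False
        if grad_norm_type is None:
            pass                                # new default: leave gradients unnormalized
        elif grad_norm_type == 'instance':
            norm_by_times = True                # per-utterance normalization (assumed)
        elif grad_norm_type == 'batch':
            norm_by_batchsize = True            # per-batch normalization (assumed)
        elif grad_norm_type == 'frame':
            norm_by_total_logits_len = True     # matches the branch visible in the diff
        else:
            raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
        return norm_by_times, norm_by_batchsize, norm_by_total_logits_len
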

@@ -54,7 +54,7 @@ class CTCLoss(nn.Layer):
             self.norm_by_total_logits_len = True
         else:
             raise ValueError(f"CTCLoss Grad Norm no support {grad_norm_type}")
-        self.kwargs = {
+        kwargs = {
             "norm_by_times": self.norm_by_times,
             "norm_by_batchsize": self.norm_by_batchsize,
             "norm_by_total_logits_len": self.norm_by_total_logits_len,
@@ -66,10 +66,9 @@ class CTCLoss(nn.Layer):
         except ValueError:
             # Some function, e.g. built-in function, are failed
             param = {}
-        self._kwargs = {k: v for k, v in self.kwargs.items() if k in param}
-        _notin = {k: v for k, v in self.kwargs.items() if k not in param}
+        self._kwargs = {k: v for k, v in kwargs.items() if k in param}
+        _notin = {k: v for k, v in kwargs.items() if k not in param}
         logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
-        #self.loss_fn = partial(self.loss.forward, **_kwargs)

     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
@@ -89,8 +88,7 @@ class CTCLoss(nn.Layer):
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         ys_pad = ys_pad.astype(paddle.int32)
-        #loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
-        loss = self.loss(logits, ys_pad, hlens, ys_lens)
+        loss = self.loss(logits, ys_pad, hlens, ys_lens, **self._kwargs)
         if self.batch_average:
             # Batch-size average
             loss = loss / B
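
The construction above probes the installed ctc_loss signature with inspect and keeps only the normalization kwargs that signature actually declares, so the newer flags degrade gracefully on older Paddle builds; forward() then splats self._kwargs at call time instead of the old partial binding. A standalone sketch of that probe (the function name is illustrative):

    import inspect

    def supported_kwargs(fn, candidates):
        """Split kwargs into those `fn` declares and those it does not."""
        try:
            params = inspect.signature(fn).parameters
        except ValueError:
            # Some callables, e.g. builtins, expose no signature at all.
            params = {}
        accepted = {k: v for k, v in candidates.items() if k in params}
        rejected = {k: v for k, v in candidates.items() if k not in params}
        return accepted, rejected

    # e.g. accepted, rejected = supported_kwargs(F.ctc_loss, {"norm_by_times": True})
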

@@ -68,7 +68,7 @@ model:
   model_conf:
     ctc_weight: 0.3
     ctc_dropoutrate: 0.0
-    ctc_grad_norm_type: instance
+    ctc_grad_norm_type: null
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
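
In YAML, null deserializes to Python None, so configs that previously pinned instance-level grad normalization now pick up the new do-nothing default. A quick round-trip check, assuming PyYAML is available:

    import yaml  # assumes PyYAML is installed

    conf = yaml.safe_load("""
    model_conf:
      ctc_weight: 0.3
      ctc_dropoutrate: 0.0
      ctc_grad_norm_type: null
    """)
    assert conf["model_conf"]["ctc_grad_norm_type"] is None  # YAML null -> Python None
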
