pull/578/head
Hui Zhang 5 years ago
parent 6622032677
commit 0bf8d8aa05

@ -145,7 +145,7 @@ class ConvStack(nn.Layer):
act='brelu') act='brelu')
out_channel = 32 out_channel = 32
self.conv_stack = nn.LayerList([ self.conv_stack = nn.Sequential([
ConvBn( ConvBn(
num_channels_in=32, num_channels_in=32,
num_channels_out=out_channel, num_channels_out=out_channel,

@ -17,9 +17,7 @@ import math
import numpy as np import numpy as np
import distutils.util import distutils.util
__all__ = [ __all__ = ['print_arguments', 'add_arguments', "log_add"]
'print_arguments', 'add_arguments', "log_add", "remove_duplicates_and_blank"
]
def print_arguments(args): def print_arguments(args):
@ -72,26 +70,3 @@ def log_add(args: List[int]) -> float:
a_max = max(args) a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args)) lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp return a_max + lsp
def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
"""ctc alignment to ctc label ids.
"abaa-acee-" -> "abaace"
Args:
hyp (List[int]): hypotheses ids, (L)
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[int]: remove dupicate ids, then remove blank id.
"""
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp

Loading…
Cancel
Save