run with aishell/asr3 (#3904)

pull/3918/head
张春乔 2 months ago committed by GitHub
parent 7fd5abd75d
commit 7dc806dc1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -37,8 +37,6 @@ if __name__ == "__main__":
# save asr result to # save asr result to
parser.add_argument( parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.') '--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args() args = parser.parse_args()
print_arguments(args, globals()) print_arguments(args, globals())

@ -104,11 +104,6 @@ def main(config, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = default_argument_parser() parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
args = parser.parse_args() args = parser.parse_args()
config = CfgNode(new_allowed=True) config = CfgNode(new_allowed=True)

@ -714,13 +714,13 @@ class MultiheadAttention(nn.Layer):
else: else:
if self.beam_size > 1 and bsz == key.size(1): if self.beam_size > 1 and bsz == key.size(1):
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C] # key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
key = key.view( key = key.reshape(
key.size(0), -1, self.beam_size, [key.size(0), -1, self.beam_size,
key.size(2))[:, :, 0, :] key.size(2)])[:, :, 0, :]
if key_padding_mask is not None: if key_padding_mask is not None:
key_padding_mask = key_padding_mask.view( key_padding_mask = key_padding_mask.reshape(
-1, self.beam_size, [-1, self.beam_size,
key_padding_mask.size(1))[:, 0, :] key_padding_mask.size(1)])[:, 0, :]
k = self.k_proj(key) k = self.k_proj(key)
v = self.v_proj(key) v = self.v_proj(key)

@ -88,7 +88,7 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True) out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else: else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True) wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths out = wav_sum / lengths.astype(wav_sum.dtype)
elif amp_type == "peak": elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0] out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0]
else: else:
@ -248,4 +248,4 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
hhpf[pad] += 1 hhpf[pad] += 1
# Adding filters creates notch filter # Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1) return (hlpf + hhpf).reshape([1, -1, 1])

@ -743,7 +743,7 @@ class SpecAugment(paddle.nn.Layer):
time = x.shape[2] time = x.shape[2]
if time - window <= window: if time - window <= window:
return x.view(*original_size) return x.reshape([*original_size])
# compute center and corresponding window # compute center and corresponding window
c = paddle.randint(window, time - window, (1, ))[0] c = paddle.randint(window, time - window, (1, ))[0]
@ -762,7 +762,7 @@ class SpecAugment(paddle.nn.Layer):
x[:, :, :w] = left x[:, :, :w] = left
x[:, :, w:] = right x[:, :, w:] = right
return x.view(*original_size) return x.reshape([*original_size])
def mask_along_axis(self, x, dim): def mask_along_axis(self, x, dim):
"""Mask along time or frequency axis. """Mask along time or frequency axis.
@ -775,7 +775,7 @@ class SpecAugment(paddle.nn.Layer):
""" """
original_size = x.shape original_size = x.shape
if x.dim() == 4: if x.dim() == 4:
x = x.view(-1, x.shape[2], x.shape[3]) x = x.reshape([-1, x.shape[2], x.shape[3]])
batch, time, fea = x.shape batch, time, fea = x.shape
@ -795,7 +795,7 @@ class SpecAugment(paddle.nn.Layer):
(batch, n_mask)).unsqueeze(2) (batch, n_mask)).unsqueeze(2)
# compute masks # compute masks
arange = paddle.arange(end=D).view(1, 1, -1) arange = paddle.arange(end=D).reshape([1, 1, -1])
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len)) mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
mask = mask.any(axis=1) mask = mask.any(axis=1)
@ -811,7 +811,7 @@ class SpecAugment(paddle.nn.Layer):
# same to x.masked_fill_(mask, val) # same to x.masked_fill_(mask, val)
y = paddle.full(x.shape, val, x.dtype) y = paddle.full(x.shape, val, x.dtype)
x = paddle.where(mask, y, x) x = paddle.where(mask, y, x)
return x.view(*original_size) return x.reshape([*original_size])
class TimeDomainSpecAugment(nn.Layer): class TimeDomainSpecAugment(nn.Layer):

Loading…
Cancel
Save