run with aishell/asr3 (#3904)

pull/3918/head
张春乔 1 month ago committed by GitHub
parent 7fd5abd75d
commit 7dc806dc1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -37,8 +37,6 @@ if __name__ == "__main__":
# save asr result to
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -104,11 +104,6 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)

@ -714,13 +714,13 @@ class MultiheadAttention(nn.Layer):
else:
if self.beam_size > 1 and bsz == key.size(1):
# key is [T, bsz*beam_size, C], reduce to [T, bsz, C]
key = key.view(
key.size(0), -1, self.beam_size,
key.size(2))[:, :, 0, :]
key = key.reshape(
[key.size(0), -1, self.beam_size,
key.size(2)])[:, :, 0, :]
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.view(
-1, self.beam_size,
key_padding_mask.size(1))[:, 0, :]
key_padding_mask = key_padding_mask.reshape(
[-1, self.beam_size,
key_padding_mask.size(1)])[:, 0, :]
k = self.k_proj(key)
v = self.v_proj(key)

@ -88,7 +88,7 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
out = wav_sum / lengths.astype(wav_sum.dtype)
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0]
else:
@ -248,4 +248,4 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
hhpf[pad] += 1
# Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1)
return (hlpf + hhpf).reshape([1, -1, 1])

@ -743,7 +743,7 @@ class SpecAugment(paddle.nn.Layer):
time = x.shape[2]
if time - window <= window:
return x.view(*original_size)
return x.reshape([*original_size])
# compute center and corresponding window
c = paddle.randint(window, time - window, (1, ))[0]
@ -762,7 +762,7 @@ class SpecAugment(paddle.nn.Layer):
x[:, :, :w] = left
x[:, :, w:] = right
return x.view(*original_size)
return x.reshape([*original_size])
def mask_along_axis(self, x, dim):
"""Mask along time or frequency axis.
@ -775,7 +775,7 @@ class SpecAugment(paddle.nn.Layer):
"""
original_size = x.shape
if x.dim() == 4:
x = x.view(-1, x.shape[2], x.shape[3])
x = x.reshape([-1, x.shape[2], x.shape[3]])
batch, time, fea = x.shape
@ -795,7 +795,7 @@ class SpecAugment(paddle.nn.Layer):
(batch, n_mask)).unsqueeze(2)
# compute masks
arange = paddle.arange(end=D).view(1, 1, -1)
arange = paddle.arange(end=D).reshape([1, 1, -1])
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
mask = mask.any(axis=1)
@ -811,7 +811,7 @@ class SpecAugment(paddle.nn.Layer):
# same as x.masked_fill_(mask, val)
y = paddle.full(x.shape, val, x.dtype)
x = paddle.where(mask, y, x)
return x.view(*original_size)
return x.reshape([*original_size])
class TimeDomainSpecAugment(nn.Layer):

Loading…
Cancel
Save