[asr] logfbank with dither (#1179)

* fix logfbank dither

* format
pull/1191/head
Hui Zhang 3 years ago committed by GitHub
parent 6b536e3fb9
commit d852aee2ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,13 +3,6 @@ data:
train_manifest: data/manifest.train train_manifest: data/manifest.train
dev_manifest: data/manifest.dev dev_manifest: data/manifest.dev
test_manifest: data/manifest.test test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator: collator:
vocab_filepath: data/lang_char/vocab.txt vocab_filepath: data/lang_char/vocab.txt

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -73,7 +73,7 @@ model:
training: training:
n_epoch: 120 n_epoch: 240
accum_grad: 2 accum_grad: 2
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -28,7 +28,8 @@ def get_baker_data(root_dir):
alignment_files = [f for f in alignment_files if f.stem not in exclude] alignment_files = [f for f in alignment_files if f.stem not in exclude]
data_dict = defaultdict(dict) data_dict = defaultdict(dict)
for alignment_fp in alignment_files: for alignment_fp in alignment_files:
alignment = textgrid.openTextgrid(alignment_fp, includeEmptyIntervals=True) alignment = textgrid.openTextgrid(
alignment_fp, includeEmptyIntervals=True)
# only with baker's annotation # only with baker's annotation
utt_id = alignment.tierNameList[0].split(".")[0] utt_id = alignment.tierNameList[0].split(".")[0]
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList intervals = alignment.tierDict[alignment.tierNameList[0]].entryList

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -584,8 +584,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyp_content = hyp[0] hyp_content = hyp[0]
# Prevent the hyp is empty # Prevent the hyp is empty
if len(hyp_content) == 0: if len(hyp_content) == 0:
hyp_content = (self.ctc.blank_id,) hyp_content = (self.ctc.blank_id, )
hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long) hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content) hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id) hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(

@ -319,7 +319,7 @@ class LogMelSpectrogramKaldi():
fmin=20, fmin=20,
fmax=None, fmax=None,
eps=1e-10, eps=1e-10,
dither=False): dither=1.0):
self.fs = fs self.fs = fs
self.n_mels = n_mels self.n_mels = n_mels
self.n_fft = n_fft self.n_fft = n_fft
@ -374,7 +374,7 @@ class LogMelSpectrogramKaldi():
Returns: Returns:
np.ndarray: (T, D) np.ndarray: (T, D)
""" """
dither = self.dither if train else False dither = self.dither if train else 0.0
if x.ndim != 1: if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]") raise ValueError("Not support x: [Time, Channel]")

@ -242,8 +242,7 @@ def train_sp(args, config):
def main(): def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
description="Train a HiFiGAN model.")
parser.add_argument( parser.add_argument(
"--config", type=str, help="config file to overwrite default config.") "--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")

@ -14,16 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" Create release notes with the issues from a milestone. """ Create release notes with the issues from a milestone.
python3 release_notes.py -c didi delta v.xxxxx python3 release_notes.py -c didi delta v.xxxxx
""" """
import sys
import json
import argparse import argparse
import urllib.request
import collections import collections
import json
import sys
import urllib.request
github_url = 'https://api.github.com/repos' github_url = 'https://api.github.com/repos'
@ -42,50 +40,51 @@ if __name__ == '__main__':
metavar='user', metavar='user',
type=str, type=str,
default='paddlepaddle', default='paddlepaddle',
help='github user: paddlepaddle' help='github user: paddlepaddle')
)
parser.add_argument( parser.add_argument(
'repository', 'repository',
metavar='repository', metavar='repository',
type=str, type=str,
default='paddlespeech', default='paddlespeech',
help='github repository: paddlespeech' help='github repository: paddlespeech')
)
parser.add_argument( parser.add_argument(
'milestone', 'milestone',
metavar='milestone', metavar='milestone',
type=str, type=str,
help='name of used milestone: v0.3.3' help='name of used milestone: v0.3.3')
)
parser.add_argument( parser.add_argument(
'-c', '--closed', '-c',
'--closed',
help='Fetch closed milestones/issues', help='Fetch closed milestones/issues',
action='store_true' action='store_true')
)
parser.print_help() parser.print_help()
args = parser.parse_args() args = parser.parse_args()
# Fetch milestone infos # Fetch milestone infos
url = "%s/%s/%s/milestones" % ( url = "%s/%s/%s/milestones" % (github_url, args.user, args.repository)
github_url,
args.user,
args.repository
)
headers = { headers = {
'Origin': 'https://github.com', 'Origin':
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) ' 'https://github.com',
'User-Agent':
'Mozilla/5.0 (X11; Linux x86_64) '
'AppleWebKit/537.11 (KHTML, like Gecko) ' 'AppleWebKit/537.11 (KHTML, like Gecko) '
'Chrome/23.0.1271.64 Safari/537.11', 'Chrome/23.0.1271.64 Safari/537.11',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'Accept':
'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3', 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'Accept-Encoding': 'none', 'Accept-Charset':
'Accept-Language': 'en-US,en;q=0.8', 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
'Connection': 'keep-alive'} 'Accept-Encoding':
'none',
'Accept-Language':
'en-US,en;q=0.8',
'Connection':
'keep-alive'
}
if args.closed: if args.closed:
url += "?state=closed" url += "?state=closed"
@ -107,14 +106,9 @@ if __name__ == '__main__':
if not milestone_id: if not milestone_id:
parser.error('Cannot find milestone') parser.error('Cannot find milestone')
# Get milestone related issue info # Get milestone related issue info
url = '%s/%s/%s/issues?milestone=%d' % ( url = '%s/%s/%s/issues?milestone=%d' % (github_url, args.user,
github_url, args.repository, milestone_id)
args.user,
args.repository,
milestone_id
)
if args.closed: if args.closed:
url += "&state=closed" url += "&state=closed"
@ -155,14 +149,9 @@ if __name__ == '__main__':
# print('# %s\n%s' % (key, ''.join(value)), file=f) # print('# %s\n%s' % (key, ''.join(value)), file=f)
# print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f) # print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f)
# Get milestone related PR info # Get milestone related PR info
url = '%s/%s/%s/pulls?milestone=%d' % ( url = '%s/%s/%s/pulls?milestone=%d' % (github_url, args.user,
github_url, args.repository, milestone_id)
args.user,
args.repository,
milestone_id
)
if args.closed: if args.closed:
url += "&state=closed" url += "&state=closed"
@ -184,12 +173,9 @@ if __name__ == '__main__':
labels.append(label['name']) labels.append(label['name'])
thanks_to.append('@%s' % (issue['user']['login'])) thanks_to.append('@%s' % (issue['user']['login']))
final_data.append(' * **[%s]** - %s #%d by **@%s**\n' % ( final_data.append(' * **[%s]** - %s #%d by **@%s**\n' %
label['name'], (label['name'], issue['title'], issue['number'],
issue['title'], issue['user']['login']))
issue['number'],
issue['user']['login']
))
dic = collections.defaultdict(set) dic = collections.defaultdict(set)
for l_release in list(set(labels)): for l_release in list(set(labels)):
@ -201,4 +187,7 @@ if __name__ == '__main__':
with open(f"release_note_pulls_{args.milestone}.md", 'w') as f: with open(f"release_note_pulls_{args.milestone}.md", 'w') as f:
for key, value in dic.items(): for key, value in dic.items():
print('# %s\n%s' % (key, ''.join(value)), file=f) print('# %s\n%s' % (key, ''.join(value)), file=f)
print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f) print(
'# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' %
(' '.join(list(set(thanks_to))))),
file=f)

Loading…
Cancel
Save