[asr] logfbank with dither (#1179)

* fix logfbank dither

* format
pull/1191/head
Hui Zhang 3 years ago committed by GitHub
parent 6b536e3fb9
commit d852aee2ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,13 +3,6 @@ data:
train_manifest: data/manifest.train train_manifest: data/manifest.train
dev_manifest: data/manifest.dev dev_manifest: data/manifest.dev
test_manifest: data/manifest.test test_manifest: data/manifest.test
min_input_len: 0.5
max_input_len: 20.0 # second
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator: collator:
vocab_filepath: data/lang_char/vocab.txt vocab_filepath: data/lang_char/vocab.txt

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -73,7 +73,7 @@ model:
training: training:
n_epoch: 120 n_epoch: 240
accum_grad: 2 accum_grad: 2
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -28,7 +28,8 @@ def get_baker_data(root_dir):
alignment_files = [f for f in alignment_files if f.stem not in exclude] alignment_files = [f for f in alignment_files if f.stem not in exclude]
data_dict = defaultdict(dict) data_dict = defaultdict(dict)
for alignment_fp in alignment_files: for alignment_fp in alignment_files:
alignment = textgrid.openTextgrid(alignment_fp, includeEmptyIntervals=True) alignment = textgrid.openTextgrid(
alignment_fp, includeEmptyIntervals=True)
# only with baker's annotation # only with baker's annotation
utt_id = alignment.tierNameList[0].split(".")[0] utt_id = alignment.tierNameList[0].split(".")[0]
intervals = alignment.tierDict[alignment.tierNameList[0]].entryList intervals = alignment.tierDict[alignment.tierNameList[0]].entryList

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -5,7 +5,7 @@ process:
n_mels: 80 n_mels: 80
n_shift: 160 n_shift: 160
win_length: 400 win_length: 400
dither: true dither: 0.1
- type: cmvn_json - type: cmvn_json
cmvn_path: data/mean_std.json cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument # these three processes are a.k.a. SpecAugument

@ -584,8 +584,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
hyp_content = hyp[0] hyp_content = hyp[0]
# Prevent the hyp is empty # Prevent the hyp is empty
if len(hyp_content) == 0: if len(hyp_content) == 0:
hyp_content = (self.ctc.blank_id,) hyp_content = (self.ctc.blank_id, )
hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long) hyp_content = paddle.to_tensor(
hyp_content, place=device, dtype=paddle.long)
hyp_list.append(hyp_content) hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, True, self.ignore_id) hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(

@ -319,7 +319,7 @@ class LogMelSpectrogramKaldi():
fmin=20, fmin=20,
fmax=None, fmax=None,
eps=1e-10, eps=1e-10,
dither=False): dither=1.0):
self.fs = fs self.fs = fs
self.n_mels = n_mels self.n_mels = n_mels
self.n_fft = n_fft self.n_fft = n_fft
@ -374,7 +374,7 @@ class LogMelSpectrogramKaldi():
Returns: Returns:
np.ndarray: (T, D) np.ndarray: (T, D)
""" """
dither = self.dither if train else False dither = self.dither if train else 0.0
if x.ndim != 1: if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]") raise ValueError("Not support x: [Time, Channel]")

@ -242,8 +242,7 @@ def train_sp(args, config):
def main(): def main():
# parse args and config and redirect to train_sp # parse args and config and redirect to train_sp
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
description="Train a HiFiGAN model.")
parser.add_argument( parser.add_argument(
"--config", type=str, help="config file to overwrite default config.") "--config", type=str, help="config file to overwrite default config.")
parser.add_argument("--train-metadata", type=str, help="training data.") parser.add_argument("--train-metadata", type=str, help="training data.")

@ -14,191 +14,180 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
""" Create release notes with the issues from a milestone. """ Create release notes with the issues from a milestone.
python3 release_notes.py -c didi delta v.xxxxx python3 release_notes.py -c didi delta v.xxxxx
""" """
import sys
import json
import argparse import argparse
import urllib.request
import collections import collections
import json
import sys
import urllib.request
github_url = 'https://api.github.com/repos' github_url = 'https://api.github.com/repos'
if __name__ == '__main__': if __name__ == '__main__':
# usage: # usage:
# 1. close milestone on github # 1. close milestone on github
# 2. python3 tools/release_notes.py -c didi delta v0.3.3 # 2. python3 tools/release_notes.py -c didi delta v0.3.3
# Parse arguments # Parse arguments
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Create a draft release with the issues from a milestone.', description='Create a draft release with the issues from a milestone.',
) )
parser.add_argument( parser.add_argument(
'user', 'user',
metavar='user', metavar='user',
type=str, type=str,
default='paddlepaddle', default='paddlepaddle',
help='github user: paddlepaddle' help='github user: paddlepaddle')
)
parser.add_argument(
parser.add_argument( 'repository',
'repository', metavar='repository',
metavar='repository', type=str,
type=str, default='paddlespeech',
default='paddlespeech', help='github repository: paddlespeech')
help='github repository: paddlespeech'
) parser.add_argument(
'milestone',
parser.add_argument( metavar='milestone',
'milestone', type=str,
metavar='milestone', help='name of used milestone: v0.3.3')
type=str,
help='name of used milestone: v0.3.3' parser.add_argument(
) '-c',
'--closed',
parser.add_argument( help='Fetch closed milestones/issues',
'-c', '--closed', action='store_true')
help='Fetch closed milestones/issues',
action='store_true' parser.print_help()
) args = parser.parse_args()
parser.print_help() # Fetch milestone infos
args = parser.parse_args() url = "%s/%s/%s/milestones" % (github_url, args.user, args.repository)
# Fetch milestone infos headers = {
url = "%s/%s/%s/milestones" % ( 'Origin':
github_url, 'https://github.com',
args.user, 'User-Agent':
args.repository 'Mozilla/5.0 (X11; Linux x86_64) '
) 'AppleWebKit/537.11 (KHTML, like Gecko) '
'Chrome/23.0.1271.64 Safari/537.11',
headers = { 'Accept':
'Origin': 'https://github.com', 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) ' 'Accept-Charset':
'AppleWebKit/537.11 (KHTML, like Gecko) ' 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
'Chrome/23.0.1271.64 Safari/537.11', 'Accept-Encoding':
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'none',
'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3', 'Accept-Language':
'Accept-Encoding': 'none', 'en-US,en;q=0.8',
'Accept-Language': 'en-US,en;q=0.8', 'Connection':
'Connection': 'keep-alive'} 'keep-alive'
}
if args.closed:
url += "?state=closed" if args.closed:
url += "?state=closed"
req = urllib.request.Request(url, headers=headers)
github_request = urllib.request.urlopen(req) req = urllib.request.Request(url, headers=headers)
if not github_request: github_request = urllib.request.urlopen(req)
parser.error('Cannot read milestone list.') if not github_request:
parser.error('Cannot read milestone list.')
decoder = json.JSONDecoder()
milestones = decoder.decode(github_request.read().decode('utf-8')) decoder = json.JSONDecoder()
github_request.close() milestones = decoder.decode(github_request.read().decode('utf-8'))
github_request.close()
print('parse milestones', file=sys.stderr)
milestone_id = None print('parse milestones', file=sys.stderr)
for milestone in milestones: milestone_id = None
if milestone['title'] == args.milestone: for milestone in milestones:
milestone_id = milestone['number'] if milestone['title'] == args.milestone:
if not milestone_id: milestone_id = milestone['number']
parser.error('Cannot find milestone') if not milestone_id:
parser.error('Cannot find milestone')
# Get milestone related issue info # Get milestone related issue info
url = '%s/%s/%s/issues?milestone=%d' % ( url = '%s/%s/%s/issues?milestone=%d' % (github_url, args.user,
github_url, args.repository, milestone_id)
args.user, if args.closed:
args.repository, url += "&state=closed"
milestone_id
) req = urllib.request.Request(url, headers=headers)
if args.closed: github_request = urllib.request.urlopen(req)
url += "&state=closed" if not github_request:
parser.error('Cannot read issue list.')
req = urllib.request.Request(url, headers=headers)
github_request = urllib.request.urlopen(req) issues = decoder.decode(github_request.read().decode('utf-8'))
if not github_request: github_request.close()
parser.error('Cannot read issue list.')
#print('parse issues', file=sys.stderr)
issues = decoder.decode(github_request.read().decode('utf-8')) #final_data = []
github_request.close() #labels = []
#thanks_to = []
#print('parse issues', file=sys.stderr) #for issue in issues:
#final_data = []
#labels = [] # for label in issue['labels']:
#thanks_to = [] # labels.append(label['name'])
#for issue in issues:
# thanks_to.append('@%s' % (issue['user']['login']))
# for label in issue['labels']: # final_data.append(' * **[%s]** - %s #%d by **@%s**\n' % (
# labels.append(label['name']) # label['name'],
# issue['title'],
# thanks_to.append('@%s' % (issue['user']['login'])) # issue['number'],
# final_data.append(' * **[%s]** - %s #%d by **@%s**\n' % ( # issue['user']['login']
# label['name'], # ))
# issue['title'],
# issue['number'], #dic = collections.defaultdict(set)
# issue['user']['login'] #for l_release in list(set(labels)):
# ))
# for f_data in final_data:
#dic = collections.defaultdict(set) # if '[%s]' % l_release in f_data:
#for l_release in list(set(labels)): # dic[l_release].add(f_data)
# for f_data in final_data: #with open(f"release_note_issues_{args.milestone}.md", 'w') as f:
# if '[%s]' % l_release in f_data: # for key, value in dic.items():
# dic[l_release].add(f_data) # print('# %s\n%s' % (key, ''.join(value)), file=f)
# print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f)
#with open(f"release_note_issues_{args.milestone}.md", 'w') as f:
# for key, value in dic.items(): # Get milestone related PR info
# print('# %s\n%s' % (key, ''.join(value)), file=f) url = '%s/%s/%s/pulls?milestone=%d' % (github_url, args.user,
# print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f) args.repository, milestone_id)
if args.closed:
url += "&state=closed"
# Get milestone related PR info
url = '%s/%s/%s/pulls?milestone=%d' % ( req = urllib.request.Request(url, headers=headers)
github_url, github_request = urllib.request.urlopen(req)
args.user, if not github_request:
args.repository, parser.error('Cannot read issue list.')
milestone_id
) issues = decoder.decode(github_request.read().decode('utf-8'))
if args.closed: github_request.close()
url += "&state=closed"
print('parse pulls', file=sys.stderr)
req = urllib.request.Request(url, headers=headers) final_data = []
github_request = urllib.request.urlopen(req) labels = []
if not github_request: thanks_to = []
parser.error('Cannot read issue list.') for issue in issues:
issues = decoder.decode(github_request.read().decode('utf-8')) for label in issue['labels']:
github_request.close() labels.append(label['name'])
print('parse pulls', file=sys.stderr) thanks_to.append('@%s' % (issue['user']['login']))
final_data = [] final_data.append(' * **[%s]** - %s #%d by **@%s**\n' %
labels = [] (label['name'], issue['title'], issue['number'],
thanks_to = [] issue['user']['login']))
for issue in issues:
dic = collections.defaultdict(set)
for label in issue['labels']: for l_release in list(set(labels)):
labels.append(label['name'])
for f_data in final_data:
thanks_to.append('@%s' % (issue['user']['login'])) if '[%s]' % l_release in f_data:
final_data.append(' * **[%s]** - %s #%d by **@%s**\n' % ( dic[l_release].add(f_data)
label['name'],
issue['title'], with open(f"release_note_pulls_{args.milestone}.md", 'w') as f:
issue['number'], for key, value in dic.items():
issue['user']['login'] print('# %s\n%s' % (key, ''.join(value)), file=f)
)) print(
'# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' %
dic = collections.defaultdict(set) (' '.join(list(set(thanks_to))))),
for l_release in list(set(labels)): file=f)
for f_data in final_data:
if '[%s]' % l_release in f_data:
dic[l_release].add(f_data)
with open(f"release_note_pulls_{args.milestone}.md", 'w') as f:
for key, value in dic.items():
print('# %s\n%s' % (key, ''.join(value)), file=f)
print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f)

Loading…
Cancel
Save