[asr] logfbank with dither ()

* fix logfbank dither

* format
pull/1191/head
Hui Zhang authored 3 years ago, committed by GitHub
parent 6b536e3fb9
commit d852aee2ff

@@ -3,13 +3,6 @@ data:
   train_manifest: data/manifest.train
   dev_manifest: data/manifest.dev
   test_manifest: data/manifest.test
-  min_input_len: 0.5
-  max_input_len: 20.0 # second
-  min_output_len: 0.0
-  max_output_len: 400.0
-  min_output_input_ratio: 0.05
-  max_output_input_ratio: 10.0
 collator:
   vocab_filepath: data/lang_char/vocab.txt

@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

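The same `process` block recurs across several recipe configs below: each list entry's `type` selects a transform and the remaining keys are its parameters, so `dither` now travels through as a numeric noise amplitude rather than a flag. A small sketch of inspecting such a config with PyYAML; the path is illustrative and the repo's actual consumer of this list may differ:

import yaml

# Illustrative path: each recipe ships its own preprocess/augmentation config
# containing a `process` list like the one in the hunks here.
with open("conf/preprocess.yaml") as f:
    conf = yaml.safe_load(f)

for entry in conf["process"]:
    if "dither" in entry:
        # After this commit the value is a float amplitude (e.g. 0.1),
        # not a boolean, so it can be handed to the feature extractor unchanged.
        print(entry.get("type"), "dither amplitude:", entry["dither"])
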
@@ -73,7 +73,7 @@ model:
 training:
-  n_epoch: 120
+  n_epoch: 240
   accum_grad: 2
   global_grad_clip: 5.0
   optim: adam

@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

@@ -28,7 +28,8 @@ def get_baker_data(root_dir):
     alignment_files = [f for f in alignment_files if f.stem not in exclude]
     data_dict = defaultdict(dict)
     for alignment_fp in alignment_files:
-        alignment = textgrid.openTextgrid(alignment_fp, includeEmptyIntervals=True)
+        alignment = textgrid.openTextgrid(
+            alignment_fp, includeEmptyIntervals=True)
         # only with baker's annotation
         utt_id = alignment.tierNameList[0].split(".")[0]
         intervals = alignment.tierDict[alignment.tierNameList[0]].entryList

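The rewrapped call above is praatio's TextGrid reader. For reference, a standalone sketch of the same access pattern, assuming a praatio version that exposes the `tierNameList`/`tierDict` interface used here and an illustrative alignment path:

from praatio import textgrid

# Illustrative path: one Baker alignment TextGrid per utterance.
alignment = textgrid.openTextgrid(
    "PhoneLabeling/000001.interval", includeEmptyIntervals=True)

# The first tier is keyed by the utterance id; its entries are
# (start, end, label) intervals.
tier_name = alignment.tierNameList[0]
utt_id = tier_name.split(".")[0]
for start, end, label in alignment.tierDict[tier_name].entryList:
    print(utt_id, start, end, label)
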
@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

@@ -5,7 +5,7 @@ process:
     n_mels: 80
     n_shift: 160
     win_length: 400
-    dither: true
+    dither: 0.1
   - type: cmvn_json
     cmvn_path: data/mean_std.json
   # these three processes are a.k.a. SpecAugument

@@ -585,7 +585,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
             # Prevent the hyp is empty
             if len(hyp_content) == 0:
                 hyp_content = (self.ctc.blank_id, )
-            hyp_content = paddle.to_tensor(hyp_content, place=device, dtype=paddle.long)
+            hyp_content = paddle.to_tensor(
+                hyp_content, place=device, dtype=paddle.long)
             hyp_list.append(hyp_content)
         hyps_pad = pad_sequence(hyp_list, True, self.ignore_id)
         hyps_lens = paddle.to_tensor(

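This hunk only rewraps the `paddle.to_tensor` call in the rescoring path, which collects variable-length hypotheses (falling back to the CTC blank id when one is empty) and pads them before attention rescoring. Below is a self-contained sketch of that padding step using plain paddle ops instead of the repo's `pad_sequence` helper; the ids and hypotheses are made up, and `paddle.long` in the diff looks like a repo-level alias for int64, so the sketch uses `paddle.int64`:

import paddle

blank_id = 0      # illustrative CTC blank id
ignore_id = -1    # illustrative padding label

# Token-id hypotheses as they might come out of CTC prefix beam search.
hyps = [(3, 7, 7, 2), (5,), ()]

hyp_list = []
for hyp_content in hyps:
    # Prevent an empty hypothesis: fall back to the blank token.
    if len(hyp_content) == 0:
        hyp_content = (blank_id, )
    hyp_list.append(paddle.to_tensor(hyp_content, dtype=paddle.int64))

# Pad every hypothesis to the longest one with ignore_id (batch-first layout).
max_len = max(int(h.shape[0]) for h in hyp_list)
padded = []
for h in hyp_list:
    pad_len = max_len - int(h.shape[0])
    if pad_len > 0:
        pad = paddle.full([pad_len], ignore_id, dtype=paddle.int64)
        h = paddle.concat([h, pad])
    padded.append(h)
hyps_pad = paddle.stack(padded)
hyps_lens = paddle.to_tensor(
    [int(h.shape[0]) for h in hyp_list], dtype=paddle.int64)

print(hyps_pad.numpy())   # shape [3, 4], short rows padded with -1
print(hyps_lens.numpy())  # [4, 1, 1]
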
@@ -319,7 +319,7 @@ class LogMelSpectrogramKaldi():
                  fmin=20,
                  fmax=None,
                  eps=1e-10,
-                 dither=False):
+                 dither=1.0):
         self.fs = fs
         self.n_mels = n_mels
         self.n_fft = n_fft
@@ -374,7 +374,7 @@ class LogMelSpectrogramKaldi():
         Returns:
             np.ndarray: (T, D)
         """
-        dither = self.dither if train else False
+        dither = self.dither if train else 0.0
         if x.ndim != 1:
             raise ValueError("Not support x: [Time, Channel]")

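These two hunks are the substantive part of the fix: `LogMelSpectrogramKaldi` now takes `dither` as a float amplitude (default 1.0) instead of a bool, and forces it to 0.0 outside training. A minimal sketch of what such an amplitude usually means, assuming Kaldi-style Gaussian dithering of the raw waveform; how the repo's `logfbank` actually applies it is outside this diff:

import numpy as np

def dither_waveform(x: np.ndarray, dither: float, train: bool = True) -> np.ndarray:
    """Add Gaussian noise with amplitude `dither` to a 1-D waveform.

    Mirrors the behaviour in the diff above: the configured amplitude is
    only applied in training; at eval time it is forced to 0.0.
    """
    dither = dither if train else 0.0
    if x.ndim != 1:
        raise ValueError("Not support x: [Time, Channel]")
    if dither > 0.0:
        x = x + dither * np.random.randn(*x.shape)
    return x

# e.g. the preprocess configs above now pass dither=0.1
wav = np.random.uniform(-1.0, 1.0, 16000).astype(np.float32)
dithered = dither_waveform(wav, dither=0.1, train=True)
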
@@ -242,8 +242,7 @@ def train_sp(args, config):
 def main():
     # parse args and config and redirect to train_sp
-    parser = argparse.ArgumentParser(
-        description="Train a HiFiGAN model.")
+    parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
     parser.add_argument(
         "--config", type=str, help="config file to overwrite default config.")
     parser.add_argument("--train-metadata", type=str, help="training data.")

@@ -14,16 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """ Create release notes with the issues from a milestone.
     python3 release_notes.py -c didi delta v.xxxxx
 """
-import sys
-import json
 import argparse
-import urllib.request
 import collections
+import json
+import sys
+import urllib.request
 github_url = 'https://api.github.com/repos'
@@ -42,50 +40,51 @@ if __name__ == '__main__':
         metavar='user',
         type=str,
         default='paddlepaddle',
-        help='github user: paddlepaddle'
-    )
+        help='github user: paddlepaddle')
     parser.add_argument(
         'repository',
         metavar='repository',
         type=str,
         default='paddlespeech',
-        help='github repository: paddlespeech'
-    )
+        help='github repository: paddlespeech')
     parser.add_argument(
         'milestone',
         metavar='milestone',
         type=str,
-        help='name of used milestone: v0.3.3'
-    )
+        help='name of used milestone: v0.3.3')
     parser.add_argument(
-        '-c', '--closed',
+        '-c',
+        '--closed',
         help='Fetch closed milestones/issues',
-        action='store_true'
-    )
+        action='store_true')
     parser.print_help()
     args = parser.parse_args()
     # Fetch milestone infos
-    url = "%s/%s/%s/milestones" % (
-        github_url,
-        args.user,
-        args.repository
-    )
+    url = "%s/%s/%s/milestones" % (github_url, args.user, args.repository)
     headers = {
-        'Origin': 'https://github.com',
-        'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) '
+        'Origin':
+        'https://github.com',
+        'User-Agent':
+        'Mozilla/5.0 (X11; Linux x86_64) '
         'AppleWebKit/537.11 (KHTML, like Gecko) '
         'Chrome/23.0.1271.64 Safari/537.11',
-        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
-        'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
-        'Accept-Encoding': 'none',
-        'Accept-Language': 'en-US,en;q=0.8',
-        'Connection': 'keep-alive'}
+        'Accept':
+        'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
+        'Accept-Charset':
+        'ISO-8859-1,utf-8;q=0.7,*;q=0.3',
+        'Accept-Encoding':
+        'none',
+        'Accept-Language':
+        'en-US,en;q=0.8',
+        'Connection':
+        'keep-alive'
+    }
     if args.closed:
         url += "?state=closed"
@@ -107,14 +106,9 @@ if __name__ == '__main__':
     if not milestone_id:
         parser.error('Cannot find milestone')
     # Get milestone related issue info
-    url = '%s/%s/%s/issues?milestone=%d' % (
-        github_url,
-        args.user,
-        args.repository,
-        milestone_id
-    )
+    url = '%s/%s/%s/issues?milestone=%d' % (github_url, args.user,
+                                            args.repository, milestone_id)
     if args.closed:
         url += "&state=closed"
@@ -155,14 +149,9 @@ if __name__ == '__main__':
     # print('# %s\n%s' % (key, ''.join(value)), file=f)
     # print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f)
     # Get milestone related PR info
-    url = '%s/%s/%s/pulls?milestone=%d' % (
-        github_url,
-        args.user,
-        args.repository,
-        milestone_id
-    )
+    url = '%s/%s/%s/pulls?milestone=%d' % (github_url, args.user,
+                                           args.repository, milestone_id)
     if args.closed:
         url += "&state=closed"
@@ -184,12 +173,9 @@ if __name__ == '__main__':
             labels.append(label['name'])
             thanks_to.append('@%s' % (issue['user']['login']))
-            final_data.append(' * **[%s]** - %s #%d by **@%s**\n' % (
-                label['name'],
-                issue['title'],
-                issue['number'],
-                issue['user']['login']
-            ))
+            final_data.append(' * **[%s]** - %s #%d by **@%s**\n' %
+                              (label['name'], issue['title'], issue['number'],
+                               issue['user']['login']))
     dic = collections.defaultdict(set)
     for l_release in list(set(labels)):
@@ -201,4 +187,7 @@ if __name__ == '__main__':
     with open(f"release_note_pulls_{args.milestone}.md", 'w') as f:
         for key, value in dic.items():
             print('# %s\n%s' % (key, ''.join(value)), file=f)
-        print('# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' % (' '.join(list(set(thanks_to))))), file=f)
+        print(
+            '# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' %
+                          (' '.join(list(set(thanks_to))))),
+            file=f)
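
After these requests the script buckets the formatted issue/PR lines by label and writes one markdown section per label plus an acknowledgements line, as in the hunk above. A compact, self-contained sketch of that final aggregation, with made-up sample entries and an illustrative output filename:

import collections

# (label, formatted line, author) triples as the script would collect them;
# the values here are placeholders, not real PRs.
entries = [
    ("asr", " * **[asr]** - example ASR change #0 by **@alice**\n", "@alice"),
    ("tts", " * **[tts]** - example TTS change #0 by **@bob**\n", "@bob"),
    ("asr", " * **[asr]** - another ASR change #0 by **@alice**\n", "@alice"),
]

dic = collections.defaultdict(set)
thanks_to = []
for label, line, author in entries:
    dic[label].add(line)
    thanks_to.append(author)

with open("release_note_example.md", 'w') as f:
    for key, value in dic.items():
        print('# %s\n%s' % (key, ''.join(value)), file=f)
    print(
        '# %s\n%s' % ('Acknowledgements', 'Special thanks to %s ' %
                      (' '.join(list(set(thanks_to))))),
        file=f)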
