wav2vec2 pipeline

pull/2374/head
tianhao zhang 3 years ago
parent 0975a332c4
commit c51da12b7f

@@ -14,3 +14,9 @@
 import _locale
 _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -14,12 +14,12 @@
from . import compliance from . import compliance
from . import datasets from . import datasets
from . import features from . import features
from . import text
from . import transform
from . import streamdata
from . import functional from . import functional
from . import io from . import io
from . import metric from . import metric
from . import sox_effects from . import sox_effects
from . import streamdata
from . import text
from . import transform
from .backends import load from .backends import load
from .backends import save from .backends import save

@@ -4,66 +4,67 @@
 # Modified from https://github.com/webdataset/webdataset
 #
 # flake8: noqa
-from .cache import cached_tarfile_samples
-from .cache import cached_tarfile_to_samples
-from .cache import lru_cleanup
-from .cache import pipe_cleaner
-from .compat import FluidWrapper
-from .compat import WebDataset
-from .compat import WebLoader
-from .extradatasets import MockDataset
-from .extradatasets import with_epoch
-from .extradatasets import with_length
-from .filters import associate
-from .filters import audio_cmvn
-from .filters import audio_compute_fbank
-from .filters import audio_data_filter
-from .filters import audio_padding
-from .filters import audio_resample
-from .filters import audio_spec_aug
-from .filters import audio_tokenize
-from .filters import batched
-from .filters import decode
-from .filters import detshuffle
-from .filters import extract_keys
-from .filters import getfirst
-from .filters import info
-from .filters import map
-from .filters import map_dict
-from .filters import map_tuple
-from .filters import pipelinefilter
-from .filters import placeholder
-from .filters import rename
-from .filters import rename_keys
-from .filters import select
-from .filters import shuffle
-from .filters import slice
-from .filters import sort
-from .filters import to_tuple
-from .filters import transform_with
-from .filters import unbatched
-from .filters import xdecode
-from .handlers import ignore_and_continue
-from .handlers import ignore_and_stop
-from .handlers import reraise_exception
-from .handlers import warn_and_continue
-from .handlers import warn_and_stop
-from .mix import RandomMix
-from .mix import RoundRobin
+from .cache import (
+    cached_tarfile_samples,
+    cached_tarfile_to_samples,
+    lru_cleanup,
+    pipe_cleaner,
+)
+from .compat import WebDataset, WebLoader, FluidWrapper
+from .extradatasets import MockDataset, with_epoch, with_length
+from .filters import (
+    associate,
+    batched,
+    decode,
+    detshuffle,
+    extract_keys,
+    getfirst,
+    info,
+    map,
+    map_dict,
+    map_tuple,
+    pipelinefilter,
+    rename,
+    rename_keys,
+    audio_resample,
+    select,
+    shuffle,
+    slice,
+    to_tuple,
+    transform_with,
+    unbatched,
+    xdecode,
+    audio_data_filter,
+    audio_tokenize,
+    audio_resample,
+    audio_compute_fbank,
+    audio_spec_aug,
+    sort,
+    audio_padding,
+    audio_cmvn,
+    placeholder,
+)
+from .handlers import (
+    ignore_and_continue,
+    ignore_and_stop,
+    reraise_exception,
+    warn_and_continue,
+    warn_and_stop,
+)
 from .pipeline import DataPipeline
-from .shardlists import MultiShardSample
-from .shardlists import non_empty
-from .shardlists import resampled
-from .shardlists import ResampledShards
-from .shardlists import shardspec
-from .shardlists import SimpleShardList
-from .shardlists import single_node_only
-from .shardlists import split_by_node
-from .shardlists import split_by_worker
-from .tariterators import tarfile_samples
-from .tariterators import tarfile_to_samples
-from .utils import PipelineStage
-from .utils import repeatedly
-from .writer import numpy_dumps
-from .writer import ShardWriter
-from .writer import TarWriter
+from .shardlists import (
+    MultiShardSample,
+    ResampledShards,
+    SimpleShardList,
+    non_empty,
+    resampled,
+    shardspec,
+    single_node_only,
+    split_by_node,
+    split_by_worker,
+)
+from .tariterators import tarfile_samples, tarfile_to_samples
+from .utils import PipelineStage, repeatedly
+from .writer import ShardWriter, TarWriter, numpy_dumps
+from .mix import RandomMix, RoundRobin

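The names re-exported above are the building blocks of the streaming pipeline used by the wav2vec2 recipe. As a rough orientation only, here is a minimal sketch of how they compose; this is not part of the commit, the shard pattern and parameter values are made up, and the exact signatures should be checked against filters.py.

import paddlespeech.audio.streamdata as wds

# Hypothetical shard pattern; tarfile_to_samples groups each shard's wav/txt pairs.
pipeline = wds.DataPipeline(
    wds.SimpleShardList("train-{000000..000009}.tar"),
    wds.split_by_node,
    wds.split_by_worker,
    wds.tarfile_to_samples(handler=wds.warn_and_continue),
    wds.audio_resample(resample_rate=16000),
    wds.audio_compute_fbank(num_mel_bins=80, frame_length=25, frame_shift=10),
    wds.shuffle(500),
    wds.batched(batch_size=16),
)
for batch in pipeline:
    # each batch is a list of sample dicts ('fname', 'feat', 'txt', ...)
    break
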
@@ -5,19 +5,18 @@
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #
 """Automatically decode webdataset samples."""
-import io
-import json
-import os
-import pickle
-import re
-import tempfile
+import io, json, os, pickle, re, tempfile
 from functools import partial
 import numpy as np
 """Extensions passed on to the image decoder."""
 image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
 ################################################################
 # handle basic datatypes
 ################################################################
@@ -129,7 +128,7 @@ def call_extension_handler(key, data, f, extensions):
         target = target.split(".")
         if len(target) > len(extension):
             continue
-        if extension[-len(target):] == target:
+        if extension[-len(target) :] == target:
             return f(data)
     return None
@@ -269,6 +268,7 @@ def imagehandler(imagespec, extensions=image_extensions):
 ################################################################
 # torch video
 ################################################################
 '''
 def torch_video(key, data):
     """Decode video using the torchvideo library.
@@ -289,6 +289,7 @@ def torch_video(key, data):
     return torchvision.io.read_video(fname, pts_unit="sec")
 '''
 ################################################################
 # paddlespeech.audio
 ################################################################
@@ -358,6 +359,7 @@ def gzfilter(key, data):
 # decode entire training amples
 ################################################################
 default_pre_handlers = [gzfilter]
 default_post_handlers = [basichandlers]
@@ -385,8 +387,7 @@ class Decoder:
             pre = default_pre_handlers
         if post is None:
             post = default_post_handlers
-        assert all(callable(h)
-                   for h in handlers), f"one of {handlers} not callable"
+        assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
         assert all(callable(h) for h in pre), f"one of {pre} not callable"
         assert all(callable(h) for h in post), f"one of {post} not callable"
         self.handlers = pre + handlers + post

@@ -2,10 +2,7 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
-import os
-import random
-import re
-import sys
+import itertools, os, random, re, sys
 from urllib.parse import urlparse
 from . import filters
@@ -43,7 +40,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
         os.remove(fname)
-def download(url, dest, chunk_size=1024**2, verbose=False):
+def download(url, dest, chunk_size=1024 ** 2, verbose=False):
     """Download a file from `url` to `dest`."""
     temp = dest + f".temp{os.getpid()}"
     with gopen.gopen(url) as stream:
@@ -72,7 +69,8 @@ def get_file_cached(
     cache_size=-1,
     cache_dir=None,
     url_to_name=pipe_cleaner,
-    verbose=False, ):
+    verbose=False,
+):
     if cache_size == -1:
         cache_size = default_cache_size
     if cache_dir is None:
@@ -116,7 +114,8 @@ def cached_url_opener(
     url_to_name=pipe_cleaner,
     validator=check_tar_format,
     verbose=False,
-    always=False, ):
+    always=False,
+):
     """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
     verbose = verbose or verbose_cache
     for sample in data:
@@ -133,7 +132,8 @@ def cached_url_opener(
                     cache_size=cache_size,
                     cache_dir=cache_dir,
                     url_to_name=url_to_name,
-                    verbose=verbose, )
+                    verbose=verbose,
+                )
             if verbose:
                 print("# opening %s" % dest, file=sys.stderr)
             assert os.path.exists(dest)
@@ -143,8 +143,9 @@ def cached_url_opener(
                     data = f.read(200)
                 os.remove(dest)
                 raise ValueError(
-                    "%s (%s) is not a tar archive, but a %s, contains %s" %
-                    (dest, url, ftype, repr(data)))
+                    "%s (%s) is not a tar archive, but a %s, contains %s"
+                    % (dest, url, ftype, repr(data))
+                )
             try:
                 stream = open(dest, "rb")
                 sample.update(stream=stream)
@@ -157,7 +158,7 @@ def cached_url_opener(
                     continue
                 raise exn
         except Exception as exn:
-            exn.args = exn.args + (url, )
+            exn.args = exn.args + (url,)
             if handler(exn):
                 continue
             else:
@@ -171,7 +172,8 @@ def cached_tarfile_samples(
     cache_dir=None,
     verbose=False,
    url_to_name=pipe_cleaner,
-    always=False, ):
+    always=False,
+):
    streams = cached_url_opener(
        src,
        handler=handler,
@@ -179,7 +181,8 @@ def cached_tarfile_samples(
        cache_dir=cache_dir,
        verbose=verbose,
        url_to_name=url_to_name,
-        always=always, )
+        always=always,
+    )
    samples = tar_file_and_group_expander(streams, handler=handler)
    return samples

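For orientation, the caching helpers above keep downloaded shards in a local directory and evict old ones by creation time. A stand-alone sketch of that eviction idea follows; it assumes a flat cache directory and a byte-size budget and is not the library's lru_cleanup itself.

import glob
import os

def cleanup_cache(cache_dir, cache_size, keyfn=os.path.getctime):
    """Delete the oldest files until the directory fits the size budget."""
    files = sorted(glob.glob(os.path.join(cache_dir, "*")), key=keyfn)
    total = sum(os.path.getsize(f) for f in files)
    for fname in files:  # oldest first
        if total <= cache_size:
            break
        total -= os.path.getsize(fname)
        os.remove(fname)
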
@@ -2,17 +2,17 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
-import yaml
+from dataclasses import dataclass
+from itertools import islice
+from typing import List
+import braceexpand, yaml
 from . import autodecode
-from . import cache
-from . import filters
-from . import shardlists
-from . import tariterators
+from . import cache, filters, shardlists, tariterators
 from .filters import reraise_exception
-from .paddle_utils import DataLoader
-from .paddle_utils import IterableDataset
 from .pipeline import DataPipeline
+from .paddle_utils import DataLoader, IterableDataset
 class FluidInterface:
@@ -26,8 +26,7 @@ class FluidInterface:
         return self.compose(filters.unbatched())
     def listed(self, batchsize, partial=True):
-        return self.compose(
-            filters.batched(), batchsize=batchsize, collation_fn=None)
+        return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)
     def unlisted(self):
         return self.compose(filters.unlisted())
@@ -44,19 +43,9 @@ class FluidInterface:
     def map(self, f, handler=reraise_exception):
         return self.compose(filters.map(f, handler=handler))
-    def decode(self,
-               *args,
-               pre=None,
-               post=None,
-               only=None,
-               partial=False,
-               handler=reraise_exception):
-        handlers = [
-            autodecode.ImageHandler(x) if isinstance(x, str) else x
-            for x in args
-        ]
-        decoder = autodecode.Decoder(
-            handlers, pre=pre, post=post, only=only, partial=partial)
+    def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
+        handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
+        decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
         return self.map(decoder, handler=handler)
     def map_dict(self, handler=reraise_exception, **kw):
@@ -113,7 +102,6 @@ class FluidInterface:
     def audio_cmvn(self, cmvn_file):
         return self.compose(filters.audio_cmvn(cmvn_file))
 class WebDataset(DataPipeline, FluidInterface):
     """Small fluid-interface wrapper for DataPipeline."""
@@ -128,13 +116,13 @@ class WebDataset(DataPipeline, FluidInterface):
             cache_dir=None,
             detshuffle=False,
             nodesplitter=shardlists.single_node_only,
-            verbose=False, ):
+            verbose=False,
+    ):
         super().__init__()
         if isinstance(urls, IterableDataset):
             assert not resampled
             self.append(urls)
-        elif isinstance(urls, str) and (urls.endswith(".yaml") or
-                                        urls.endswith(".yml")):
+        elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
             with (open(urls)) as stream:
                 spec = yaml.safe_load(stream)
             assert "datasets" in spec
@@ -164,7 +152,9 @@ class WebDataset(DataPipeline, FluidInterface):
                     handler=handler,
                     verbose=verbose,
                     cache_size=cache_size,
-                    cache_dir=cache_dir, ))
+                    cache_dir=cache_dir,
+                )
+            )
 class FluidWrapper(DataPipeline, FluidInterface):

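The FluidInterface methods above let the same stages be chained in a fluent style on top of WebDataset. A hedged usage sketch; the shard pattern and keys are placeholders, the method names mirror the filters module, and whether each filter has a fluent wrapper should be verified against compat.py.

from paddlespeech.audio.streamdata import WebDataset

dataset = (
    WebDataset("shards-{000..009}.tar")   # hypothetical shard pattern
    .shuffle(1000)
    .decode()                             # autodecode handlers chosen by extension
    .to_tuple("wav", "txt")
)
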
@@ -5,10 +5,20 @@
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #
 """Train PyTorch models directly from POSIX tar archive.
 Code works locally or over HTTP connections.
 """
+import itertools as itt
+import os
+import random
+import sys
+import braceexpand
 from . import utils
 from .paddle_utils import IterableDataset
 from .utils import PipelineStage
@@ -53,7 +63,8 @@ class repeatedly(IterableDataset, PipelineStage):
         return utils.repeatedly(
             source,
             nepochs=self.nepochs,
-            nbatches=self.nbatches, )
+            nbatches=self.nbatches,
+        )
 class with_epoch(IterableDataset):

@@ -3,6 +3,7 @@
 # This file is part of the WebDataset library.
 # See the LICENSE file for licensing terms (BSD-style).
 #
+# Modified from https://github.com/webdataset/webdataset
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """A collection of iterators for data transformations.
@@ -11,29 +12,28 @@ These functions are plain iterator functions. You can find curried versions
 in webdataset.filters, and you can find IterableDataset wrappers in
 webdataset.processing.
 """
 import io
-import itertools
-import os
-import random
-import re
-import sys
-import time
 from fnmatch import fnmatch
-from functools import reduce
-import paddle
+import re
+import itertools, os, random, sys, time
+from functools import reduce, wraps
+import numpy as np
 from . import autodecode
 from . import utils
+from .paddle_utils import PaddleTensor
+from .utils import PipelineStage
 from .. import backends
 from ..compliance import kaldi
+import paddle
 from ..transform.cmvn import GlobalCMVN
-from ..transform.spec_augment import freq_mask
-from ..transform.spec_augment import time_mask
-from ..transform.spec_augment import time_warp
 from ..utils.tensor_utils import pad_sequence
-from .utils import PipelineStage
+from ..transform.spec_augment import time_warp
+from ..transform.spec_augment import time_mask
+from ..transform.spec_augment import freq_mask
 class FilterFunction(object):
     """Helper class for currying pipeline stages.
@@ -159,12 +159,10 @@ def transform_with(sample, transformers):
             result[i] = f(sample[i])
     return result
 ###
 # Iterators
 ###
 def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
     """Print information about the samples that are passing through.
@@ -280,16 +278,10 @@ def _log_keys(data, logfile=None):
 log_keys = pipelinefilter(_log_keys)
-def _minedecode(x):
-    if isinstance(x, str):
-        return autodecode.imagehandler(x)
-    else:
-        return x
 def _decode(data, *args, handler=reraise_exception, **kw):
     """Decode data based on the decoding functions given as arguments."""
-    decoder = _minedecode
+    decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x
     handlers = [decoder(x) for x in args]
     f = autodecode.Decoder(handlers, **kw)
@@ -333,24 +325,15 @@ def _rename(data, handler=reraise_exception, keep=True, **kw):
     for sample in data:
         try:
             if not keep:
-                yield {
-                    k: getfirst(sample, v, missing_is_error=True)
-                    for k, v in kw.items()
-                }
+                yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
             else:
                 def listify(v):
                     return v.split(";") if isinstance(v, str) else v
                 to_be_replaced = {x for v in kw.values() for x in listify(v)}
-                result = {
-                    k: v
-                    for k, v in sample.items() if k not in to_be_replaced
-                }
-                result.update({
-                    k: getfirst(sample, v, missing_is_error=True)
-                    for k, v in kw.items()
-                })
+                result = {k: v for k, v in sample.items() if k not in to_be_replaced}
+                result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
                 yield result
         except Exception as exn:
             if handler(exn):
@@ -398,11 +381,7 @@ def _map_dict(data, handler=reraise_exception, **kw):
 map_dict = pipelinefilter(_map_dict)
-def _to_tuple(data,
-              *args,
-              handler=reraise_exception,
-              missing_is_error=True,
-              none_is_error=None):
+def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
     """Convert dict samples to tuples."""
     if none_is_error is None:
         none_is_error = missing_is_error
@@ -411,10 +390,7 @@ def _to_tuple(data,
     for sample in data:
         try:
-            result = tuple([
-                getfirst(sample, f, missing_is_error=missing_is_error)
-                for f in args
-            ])
+            result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
             if none_is_error and any(x is None for x in result):
                 raise ValueError(f"to_tuple {args} got {sample.keys()}")
             yield result
@@ -487,28 +463,19 @@ rsample = pipelinefilter(_rsample)
 slice = pipelinefilter(itertools.islice)
-def _extract_keys(source,
-                  *patterns,
-                  duplicate_is_error=True,
-                  ignore_missing=False):
+def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
     for sample in source:
         result = []
         for pattern in patterns:
-            pattern = pattern.split(";") if isinstance(pattern,
-                                                       str) else pattern
-            matches = [
-                x for x in sample.keys()
-                if any(fnmatch("." + x, p) for p in pattern)
-            ]
+            pattern = pattern.split(";") if isinstance(pattern, str) else pattern
+            matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
             if len(matches) == 0:
                 if ignore_missing:
                     continue
                 else:
-                    raise ValueError(
-                        f"Cannot find {pattern} in sample keys {sample.keys()}.")
+                    raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
             if len(matches) > 1 and duplicate_is_error:
-                raise ValueError(
-                    f"Multiple sample keys {sample.keys()} match {pattern}.")
+                raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
             value = sample[matches[0]]
             result.append(value)
         yield tuple(result)
@@ -517,12 +484,7 @@ def _extract_keys(source,
 extract_keys = pipelinefilter(_extract_keys)
-def _rename_keys(source,
-                 *args,
-                 keep_unselected=False,
-                 must_match=True,
-                 duplicate_is_error=True,
-                 **kw):
+def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
     renamings = [(pattern, output) for output, pattern in args]
     renamings += [(pattern, output) for output, pattern in kw.items()]
     for sample in source:
@@ -542,15 +504,11 @@ def _rename_keys(source,
                     continue
                 if new_name in new_sample:
                     if duplicate_is_error:
-                        raise ValueError(
-                            f"Duplicate value in sample {sample.keys()} after rename."
-                        )
+                        raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
                     continue
                 new_sample[new_name] = value
         if must_match and not all(matched.values()):
-            raise ValueError(
-                f"Not all patterns ({matched}) matched sample keys ({sample.keys()})."
-            )
+            raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
         yield new_sample
@@ -583,8 +541,7 @@ def find_decoder(decoders, path):
     if fname.startswith("__"):
         return lambda x: x
     for pattern, fun in decoders[::-1]:
-        if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(),
-                                                      pattern):
+        if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
             return fun
     return None
@@ -594,7 +551,8 @@ def _xdecode(
     *args,
     must_decode=True,
     defaults=default_decoders,
-    **kw, ):
+    **kw,
+):
     decoders = list(defaults) + list(args)
     decoders += [("*." + k, v) for k, v in kw.items()]
     for sample in source:
@@ -617,10 +575,10 @@ def _xdecode(
         new_sample[path] = value
         yield new_sample
 xdecode = pipelinefilter(_xdecode)
 def _audio_data_filter(source,
                        frame_shift=10,
                        max_length=10240,
@@ -655,8 +613,7 @@ def _audio_data_filter(source,
         assert 'wav' in sample
         assert 'label' in sample
         # sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
-        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (
-            1000 / frame_shift)
+        num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
         if num_frames < min_length:
             continue
         if num_frames > max_length:
@@ -672,10 +629,8 @@ def _audio_data_filter(source,
             continue
         yield sample
 audio_data_filter = pipelinefilter(_audio_data_filter)
 def _audio_tokenize(source,
                     symbol_table,
                     bpe_model=None,
@@ -738,10 +693,8 @@ def _audio_tokenize(source,
         sample['label'] = label
         yield sample
 audio_tokenize = pipelinefilter(_audio_tokenize)
 def _audio_resample(source, resample_rate=16000):
     """ Resample data.
         Inplace operation.
@@ -760,17 +713,13 @@ def _audio_resample(source, resample_rate=16000):
         waveform = sample['wav']
         if sample_rate != resample_rate:
             sample['sample_rate'] = resample_rate
-            sample['wav'] = paddle.to_tensor(
-                backends.soundfile_backend.resample(
-                    waveform.numpy(),
-                    src_sr=sample_rate,
-                    target_sr=resample_rate))
+            sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
+                waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate
+            ))
         yield sample
 audio_resample = pipelinefilter(_audio_resample)
 def _audio_compute_fbank(source,
                          num_mel_bins=80,
                          frame_length=25,
@@ -797,8 +746,7 @@ def _audio_compute_fbank(source,
         waveform = sample['wav']
         waveform = waveform * (1 << 15)
         # Only keep fname, feat, label
-        mat = kaldi.fbank(
-            waveform,
+        mat = kaldi.fbank(waveform,
                           n_mels=num_mel_bins,
                           frame_length=frame_length,
                           frame_shift=frame_shift,
@@ -810,9 +758,7 @@ def _audio_compute_fbank(source,
 audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
-def _audio_spec_aug(
-        source,
+def _audio_spec_aug(source,
                     max_w=5,
                     w_inplace=True,
                     w_mode="PIL",
@@ -823,7 +769,7 @@ def _audio_spec_aug(
                     max_t=40,
                     num_t_mask=2,
                     t_inplace=True,
-                    t_replace_with_zero=False, ):
+                    t_replace_with_zero=False,):
     """ Do spec augmentation
     Inplace operation
@@ -847,23 +793,12 @@ def _audio_spec_aug(
     for sample in source:
         x = sample['feat']
         x = x.numpy()
-        x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
-        x = freq_mask(
-            x,
-            F=max_f,
-            n_mask=num_f_mask,
-            inplace=f_inplace,
-            replace_with_zero=f_replace_with_zero)
-        x = time_mask(
-            x,
-            T=max_t,
-            n_mask=num_t_mask,
-            inplace=t_inplace,
-            replace_with_zero=t_replace_with_zero)
+        x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode)
+        x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero)
+        x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero)
         sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
         yield sample
 audio_spec_aug = pipelinefilter(_audio_spec_aug)
@@ -894,10 +829,8 @@ def _sort(source, sort_size=500):
         for x in buf:
             yield x
 sort = pipelinefilter(_sort)
 def _batched(source, batch_size=16):
     """ Static batch the data by `batch_size`
@@ -917,10 +850,8 @@ def _batched(source, batch_size=16):
     if len(buf) > 0:
         yield buf
 batched = pipelinefilter(_batched)
 def dynamic_batched(source, max_frames_in_batch=12000):
     """ Dynamic batch the data until the total frames in batch
         reach `max_frames_in_batch`
@@ -961,8 +892,8 @@ def _audio_padding(source):
     """
     for sample in source:
         assert isinstance(sample, list)
-        feats_length = paddle.to_tensor(
-            [x['feat'].shape[0] for x in sample], dtype="int64")
+        feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
+                                        dtype="int64")
         order = paddle.argsort(feats_length, descending=True)
         feats_lengths = paddle.to_tensor(
             [sample[i]['feat'].shape[0] for i in order], dtype="int64")
@@ -971,20 +902,20 @@ def _audio_padding(source):
         sorted_labels = [
             paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
         ]
-        label_lengths = paddle.to_tensor(
-            [x.shape[0] for x in sorted_labels], dtype="int64")
-        padded_feats = pad_sequence(
-            sorted_feats, batch_first=True, padding_value=0)
-        padding_labels = pad_sequence(
-            sorted_labels, batch_first=True, padding_value=-1)
+        label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
+                                         dtype="int64")
+        padded_feats = pad_sequence(sorted_feats,
+                                    batch_first=True,
+                                    padding_value=0)
+        padding_labels = pad_sequence(sorted_labels,
+                                      batch_first=True,
+                                      padding_value=-1)
         yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
                label_lengths)
 audio_padding = pipelinefilter(_audio_padding)
 def _audio_cmvn(source, cmvn_file):
     global_cmvn = GlobalCMVN(cmvn_file)
     for batch in source:
@@ -995,13 +926,10 @@ def _audio_cmvn(source, cmvn_file):
         yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
                label_lengths)
 audio_cmvn = pipelinefilter(_audio_cmvn)
 def _placeholder(source):
     for data in source:
         yield data
 placeholder = pipelinefilter(_placeholder)

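All of the _-prefixed generators above become pipeline stages through pipelinefilter, which curries away the source argument. A self-contained sketch of that pattern (simplified; the real helper also relies on the FilterFunction class shown earlier):

def curried(gen):
    """Sketch of the pipelinefilter idea: f(**kw) returns a stage(source)."""
    def make_stage(*args, **kw):
        def stage(source):
            return gen(source, *args, **kw)
        return stage
    return make_stage

def _take(source, n=3):
    for i, sample in enumerate(source):
        if i >= n:
            break
        yield sample

take = curried(_take)
print(list(take(2)(iter(range(10)))))  # -> [0, 1]
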
@@ -3,12 +3,12 @@
 # This file is part of the WebDataset library.
 # See the LICENSE file for licensing terms (BSD-style).
 #
 """Open URLs by calling subcommands."""
-import os
-import re
-import sys
-from subprocess import PIPE
-from subprocess import Popen
+import os, sys, re
+from subprocess import PIPE, Popen
 from urllib.parse import urlparse
 # global used for printing additional node information during verbose output
@@ -37,7 +37,8 @@ class Pipe:
             timeout=7200.0,
             ignore_errors=False,
             ignore_status=[],
-            **kw, ):
+            **kw,
+    ):
         """Create an IO Pipe."""
         self.ignore_errors = ignore_errors
         self.ignore_status = [0] + ignore_status
@@ -74,7 +75,8 @@ class Pipe:
         if verbose:
             print(
                 f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
-                file=sys.stderr, )
+                file=sys.stderr,
+            )
         if self.status not in self.ignore_status and not self.ignore_errors:
             raise Exception(f"{self.args}: exit {self.status} (read) {info}")
@@ -112,11 +114,9 @@ class Pipe:
         self.close()
-def set_options(obj,
-                timeout=None,
-                ignore_errors=None,
-                ignore_status=None,
-                handler=None):
+def set_options(
+    obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
+):
     """Set options for Pipes.
     This function can be called on any stream. It will set pipe options only
@@ -168,14 +168,16 @@ def gopen_pipe(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141], )  # skipcq: BAN-B604
+            ignore_status=[141],
+        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        return Pipe(
            cmd,
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141], )  # skipcq: BAN-B604
+            ignore_status=[141],
+        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
@@ -194,7 +196,8 @@ def gopen_curl(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141, 23], )  # skipcq: BAN-B604
+            ignore_status=[141, 23],
+        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"curl -s -L -T - '{url}'"
        return Pipe(
@@ -202,7 +205,8 @@ def gopen_curl(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141, 26], )  # skipcq: BAN-B604
+            ignore_status=[141, 26],
+        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
@@ -222,13 +226,15 @@ def gopen_htgs(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141, 23], )  # skipcq: BAN-B604
+            ignore_status=[141, 23],
+        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        raise ValueError(f"{mode}: cannot write")
    else:
        raise ValueError(f"{mode}: unknown mode")
 def gopen_gsutil(url, mode="rb", bufsize=8192):
     """Open a URL with `curl`.
@@ -243,7 +249,8 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141, 23], )  # skipcq: BAN-B604
+            ignore_status=[141, 23],
+        )  # skipcq: BAN-B604
    elif mode[0] == "w":
        cmd = f"gsutil cp - '{url}'"
        return Pipe(
@@ -251,11 +258,13 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
            mode=mode,
            shell=True,
            bufsize=bufsize,
-            ignore_status=[141, 26], )  # skipcq: BAN-B604
+            ignore_status=[141, 26],
+        )  # skipcq: BAN-B604
    else:
        raise ValueError(f"{mode}: unknown mode")
 def gopen_error(url, *args, **kw):
     """Raise a value error.
@@ -276,7 +285,8 @@ gopen_schemes = dict(
     ftps=gopen_curl,
     scp=gopen_curl,
     gs=gopen_gsutil,
-    htgs=gopen_htgs, )
+    htgs=gopen_htgs,
+)
 def gopen(url, mode="rb", bufsize=8192, **kw):

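gopen above opens remote URLs by streaming the stdout of shell tools such as curl or gsutil through the Pipe class. A minimal stand-alone sketch of that pattern (no timeout or exit-status handling, unlike the real Pipe; the URL is a placeholder):

from subprocess import PIPE, Popen

def open_url_with_curl(url, bufsize=8192):
    """Return a readable binary stream backed by a curl subprocess."""
    proc = Popen(f"curl -s -L '{url}'", shell=True, stdout=PIPE, bufsize=bufsize)
    return proc.stdout

# stream = open_url_with_curl("https://example.com/shard-000000.tar")
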
@@ -3,6 +3,7 @@
 # This file is part of the WebDataset library.
 # See the LICENSE file for licensing terms (BSD-style).
 #
 """Pluggable exception handlers.
 These are functions that take an exception as an argument and then return...
@@ -13,8 +14,8 @@ These are functions that take an exception as an argument and then return...
 They are used as handler= arguments in much of the library.
 """
-import time
-import warnings
+import time, warnings
 def reraise_exception(exn):

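Per the docstring above, a handler receives the exception raised inside a stage and returns True to keep iterating or False to stop. A hedged sketch of a custom handler following that contract (not part of this commit):

import time
import warnings

def warn_then_stop(max_warnings=10, delay=0.05):
    """Warn on the first few errors, then ask the stage to stop."""
    state = {"count": 0}
    def handler(exn):
        state["count"] += 1
        warnings.warn(repr(exn))
        time.sleep(delay)
        return state["count"] < max_warnings
    return handler
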
@@ -5,12 +5,17 @@
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #
 """Classes for mixing samples from multiple sources."""
-import random
+import itertools, os, random, time, sys
+from functools import reduce, wraps
 import numpy as np
-from .paddle_utils import IterableDataset
+from . import autodecode, utils
+from .paddle_utils import PaddleTensor, IterableDataset
+from .utils import PipelineStage
 def round_robin_shortest(*sources):

@@ -5,11 +5,12 @@
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #
 """Mock implementations of paddle interfaces when paddle is not available."""
 try:
-    from paddle.io import DataLoader
-    from paddle.io import IterableDataset
+    from paddle.io import DataLoader, IterableDataset
 except ModuleNotFoundError:
     class IterableDataset:
@@ -21,3 +22,12 @@ except ModuleNotFoundError:
         """Empty implementation of DataLoader when paddle is not available."""
         pass
+try:
+    from paddle import Tensor as PaddleTensor
+except ModuleNotFoundError:
+    class TorchTensor:
+        """Empty implementation of PaddleTensor when paddle is not available."""
+        pass

@@ -3,12 +3,15 @@
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 #%%
-import copy
-import sys
+import copy, os, random, sys, time
+from dataclasses import dataclass
 from itertools import islice
+from typing import List
-from .paddle_utils import DataLoader
-from .paddle_utils import IterableDataset
+import braceexpand, yaml
+from .handlers import reraise_exception
+from .paddle_utils import DataLoader, IterableDataset
 from .utils import PipelineStage
@@ -19,7 +22,8 @@ def add_length_method(obj):
     Combined = type(
         obj.__class__.__name__ + "_Length",
         (obj.__class__, IterableDataset),
-        {"__len__": length}, )
+        {"__len__": length},
+    )
     obj.__class__ = Combined
     return obj

@@ -4,30 +4,28 @@
 # This file is part of the WebDataset library.
 # See the LICENSE file for licensing terms (BSD-style).
 #
+# Modified from https://github.com/webdataset/webdataset
 """Train PyTorch models directly from POSIX tar archive.
 Code works locally or over HTTP connections.
 """
-import os
-import random
-import sys
-import time
-from dataclasses import dataclass
-from dataclasses import field
+import os, random, sys, time
+from dataclasses import dataclass, field
 from itertools import islice
 from typing import List
-import braceexpand
-import yaml
+import braceexpand, yaml
 from . import utils
-from ..utils.log import Logger
 from .filters import pipelinefilter
 from .paddle_utils import IterableDataset
-logger = Logger(__name__)
+from ..utils.log import Logger
+logger = Logger(__name__)
 def expand_urls(urls):
     if isinstance(urls, str):
         urllist = urls.split("::")
@@ -66,8 +64,7 @@ class SimpleShardList(IterableDataset):
 def split_by_node(src, group=None):
-    rank, world_size, worker, num_workers = utils.paddle_worker_info(
-        group=group)
+    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
     logger.info(f"world_size:{world_size}, rank:{rank}")
     if world_size > 1:
         for s in islice(src, rank, None, world_size):
@@ -78,11 +75,9 @@ def split_by_node(src, group=None):
 def single_node_only(src, group=None):
-    rank, world_size, worker, num_workers = utils.paddle_worker_info(
-        group=group)
+    rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
     if world_size > 1:
-        raise ValueError(
-            "input pipeline needs to be reconfigured for multinode training")
+        raise ValueError("input pipeline needs to be reconfigured for multinode training")
     for s in src:
         yield s
@@ -109,8 +104,7 @@ def resampled_(src, n=sys.maxsize):
     rng = random.Random(seed)
     print("# resampled loading", file=sys.stderr)
     items = list(src)
-    print(
-        f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
+    print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
     for i in range(n):
         yield rng.choice(items)
@@ -124,9 +118,7 @@ def non_empty(src):
         yield s
         count += 1
     if count == 0:
-        raise ValueError(
-            "pipeline stage received no data at all and this was declared as an error"
-        )
+        raise ValueError("pipeline stage received no data at all and this was declared as an error")
 @dataclass
@@ -146,6 +138,10 @@ def expand(s):
     return os.path.expanduser(os.path.expandvars(s))
+class MultiShardSample(IterableDataset):
+    def __init__(self, fname):
+        """Construct a shardlist from multiple sources using a YAML spec."""
+        self.epoch = -1
 class MultiShardSample(IterableDataset):
     def __init__(self, fname):
         """Construct a shardlist from multiple sources using a YAML spec."""
@@ -160,23 +156,20 @@ class MultiShardSample(IterableDataset):
         else:
             with open(fname) as stream:
                 spec = yaml.safe_load(stream)
-        assert set(spec.keys()).issubset(
-            set("prefix datasets buckets".split())), list(spec.keys())
+        assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
         prefix = expand(spec.get("prefix", ""))
         self.sources = []
         for ds in spec["datasets"]:
-            assert set(ds.keys()).issubset(
-                set("buckets name shards resample choose".split())), list(
-                    ds.keys())
+            assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
+                ds.keys()
+            )
             buckets = ds.get("buckets", spec.get("buckets", []))
             if isinstance(buckets, str):
                 buckets = [buckets]
             buckets = [expand(s) for s in buckets]
             if buckets == []:
                 buckets = [""]
-            assert len(
-                buckets
-            ) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
+            assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
             bucket = buckets[0]
             name = ds.get("name", "@" + bucket)
             urls = ds["shards"]
@@ -184,19 +177,15 @@ class MultiShardSample(IterableDataset):
                 urls = [urls]
             # urls = [u for url in urls for u in braceexpand.braceexpand(url)]
             urls = [
-                prefix + os.path.join(bucket, u)
-                for url in urls for u in braceexpand.braceexpand(expand(url))
+                prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
             ]
             resample = ds.get("resample", -1)
             nsample = ds.get("choose", -1)
             if nsample > len(urls):
-                raise ValueError(
-                    f"perepoch {nsample} must be no greater than the number of shards"
-                )
+                raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
             if (nsample > 0) and (resample > 0):
                 raise ValueError("specify only one of perepoch or choose")
-            entry = MSSource(
-                name=name, urls=urls, perepoch=nsample, resample=resample)
+            entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
             self.sources.append(entry)
             print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
@@ -214,7 +203,7 @@ class MultiShardSample(IterableDataset):
                 # sample without replacement
                 l = list(source.urls)
                 self.rng.shuffle(l)
-                l = l[:source.perepoch]
+                l = l[: source.perepoch]
             else:
                 l = list(source.urls)
             result += l
@@ -242,7 +231,8 @@ class ResampledShards(IterableDataset):
             urls,
             nshards=sys.maxsize,
             worker_seed=None,
-            deterministic=False, ):
+            deterministic=False,
+    ):
         """Sample shards from the shard list with replacement.
         :param urls: a list of URLs as a Python list or brace notation string
@@ -262,8 +252,7 @@ class ResampledShards(IterableDataset):
         if self.deterministic:
             seed = utils.make_seed(self.worker_seed(), self.epoch)
         else:
-            seed = utils.make_seed(self.worker_seed(), self.epoch,
-                                   os.getpid(), time.time_ns(), os.urandom(4))
+            seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
         if os.environ.get("WDS_SHOW_SEED", "0") == "1":
             print(f"# ResampledShards seed {seed}")
         self.rng = random.Random(seed)

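split_by_node and split_by_worker above shard the URL stream by striding it with islice, using the rank reported by paddle_worker_info. The same idea in a self-contained sketch, with rank and world size passed explicitly:

from itertools import islice

def split_shards(src, rank, world_size):
    if world_size > 1:
        yield from islice(src, rank, None, world_size)
    else:
        yield from src

shards = [f"shard-{i:06d}.tar" for i in range(8)]
print(list(split_shards(iter(shards), rank=1, world_size=4)))
# -> ['shard-000001.tar', 'shard-000005.tar']
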
@@ -3,12 +3,13 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 # This file is part of the WebDataset library.
 # See the LICENSE file for licensing terms (BSD-style).
 # Modified from https://github.com/webdataset/webdataset
 # Modified from wenet(https://github.com/wenet-e2e/wenet)
 """Low level iteration functions for tar archives."""
-import random
-import re
-import tarfile
+import random, re, tarfile
 import braceexpand
@@ -26,7 +27,6 @@ import numpy as np
 AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
 def base_plus_ext(path):
     """Split off all file extensions.
@@ -47,8 +47,12 @@ def valid_sample(sample):
     :param sample: sample to be checked
     """
-    return (sample is not None and isinstance(sample, dict) and
-            len(list(sample.keys())) > 0 and not sample.get("__bad__", False))
+    return (
+        sample is not None
+        and isinstance(sample, dict)
+        and len(list(sample.keys())) > 0
+        and not sample.get("__bad__", False)
+    )
 # FIXME: UNUSED
@@ -75,16 +79,16 @@ def url_opener(data, handler=reraise_exception, **kw):
             sample.update(stream=stream)
             yield sample
         except Exception as exn:
-            exn.args = exn.args + (url, )
+            exn.args = exn.args + (url,)
             if handler(exn):
                 continue
             else:
                 break
-def tar_file_iterator(fileobj,
-                      skip_meta=r"__[^/]*__($|/)",
-                      handler=reraise_exception):
+def tar_file_iterator(
+    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
+):
     """Iterate over tar file, yielding filename, content pairs for the given tar stream.
     :param fileobj: byte stream suitable for tarfile
@@ -99,8 +103,11 @@ def tar_file_iterator(fileobj,
                 continue
             if fname is None:
                 continue
-            if ("/" not in fname and fname.startswith(meta_prefix) and
-                    fname.endswith(meta_suffix)):
+            if (
+                "/" not in fname
+                and fname.startswith(meta_prefix)
+                and fname.endswith(meta_suffix)
+            ):
                 # skipping metadata for now
                 continue
             if skip_meta is not None and re.match(skip_meta, fname):
@@ -111,10 +118,8 @@ def tar_file_iterator(fileobj,
             assert pos > 0
             prefix, postfix = name[:pos], name[pos + 1:]
             if postfix == 'wav':
-                waveform, sample_rate = paddlespeech.audio.load(
-                    stream.extractfile(tarinfo), normal=False)
-                result = dict(
-                    fname=prefix, wav=waveform, sample_rate=sample_rate)
+                waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
+                result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
             else:
                 txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
                 result = dict(fname=prefix, txt=txt)
@@ -123,17 +128,16 @@ def tar_file_iterator(fileobj,
             stream.members = []
         except Exception as exn:
             if hasattr(exn, "args") and len(exn.args) > 0:
-                exn.args = (exn.args[0] + " @ " + str(fileobj), ) + exn.args[1:]
+                exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
             if handler(exn):
                 continue
             else:
                 break
     del stream
-def tar_file_and_group_iterator(fileobj,
-                                skip_meta=r"__[^/]*__($|/)",
-                                handler=reraise_exception):
+def tar_file_and_group_iterator(
+    fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
+):
     """ Expand a stream of open tar files into a stream of tar file contents.
         And groups the file with same prefix
@@ -163,11 +167,8 @@ def tar_file_and_group_iterator(fileobj,
                     if postfix == 'txt':
                         example['txt'] = file_obj.read().decode('utf8').strip()
                     elif postfix in AUDIO_FORMAT_SETS:
-                        waveform, sample_rate = paddlespeech.audio.load(
-                            file_obj, normal=False)
-                        waveform = paddle.to_tensor(
-                            np.expand_dims(np.array(waveform), 0),
-                            dtype=paddle.float32)
+                        waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
+                        waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
                         example['wav'] = waveform
                         example['sample_rate'] = sample_rate
@@ -175,8 +176,7 @@ def tar_file_and_group_iterator(fileobj,
                         example[postfix] = file_obj.read()
                 except Exception as exn:
                     if hasattr(exn, "args") and len(exn.args) > 0:
-                        exn.args = (exn.args[0] + " @ " + str(fileobj),
-                                    ) + exn.args[1:]
+                        exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
                     if handler(exn):
                         continue
                     else:
@@ -189,7 +189,6 @@ def tar_file_and_group_iterator(fileobj,
             yield example
     stream.close()
 def tar_file_expander(data, handler=reraise_exception):
     """Expand a stream of open tar files into a stream of tar file contents.
@@ -201,8 +200,9 @@ def tar_file_expander(data, handler=reraise_exception):
             assert isinstance(source, dict)
             assert "stream" in source
             for sample in tar_file_iterator(source["stream"]):
-                assert (isinstance(sample, dict) and "data" in sample and
-                        "fname" in sample)
+                assert (
+                    isinstance(sample, dict) and "data" in sample and "fname" in sample
+                )
                 sample["__url__"] = url
                 yield sample
         except Exception as exn:
@@ -213,6 +213,8 @@ def tar_file_expander(data, handler=reraise_exception):
             break
 def tar_file_and_group_expander(data, handler=reraise_exception):
     """Expand a stream of open tar files into a stream of tar file contents.
@@ -224,8 +226,9 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
             assert isinstance(source, dict)
             assert "stream" in source
             for sample in tar_file_and_group_iterator(source["stream"]):
-                assert (isinstance(sample, dict) and "wav" in sample and
-                        "txt" in sample and "fname" in sample)
+                assert (
+                    isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
+                )
                 sample["__url__"] = url
                 yield sample
         except Exception as exn:
@@ -236,11 +239,7 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
             break
-def group_by_keys(data,
-                  keys=base_plus_ext,
-                  lcase=True,
-                  suffixes=None,
-                  handler=None):
+def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
     """Return function over iterator that groups key, value pairs into samples.
     :param keys: function that splits the key into key and extension (base_plus_ext)
@@ -255,8 +254,8 @@ def group_by_keys(data,
             print(
                 prefix,
                 suffix,
-                current_sample.keys()
-                if isinstance(current_sample, dict) else None, )
+                current_sample.keys() if isinstance(current_sample, dict) else None,
+            )
         if prefix is None:
             continue
         if lcase:

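# Illustrative sketch of the grouping contract above: members of one tar shard that
# share a key prefix (e.g. "utt1.wav", "utt1.txt") are collected into a single sample
# dict. The helper below is a hypothetical, minimal stand-in for what
# tar_file_and_group_iterator is expected to yield, shown only to make the behaviour concrete.
def _group_by_prefix_sketch(names):
    """Group member names like ["utt1.wav", "utt1.txt"] by their shared prefix."""
    samples = {}
    for name in names:
        prefix, postfix = name.rsplit('.', 1)
        samples.setdefault(prefix, {"fname": prefix})[postfix] = name
    return list(samples.values())
# _group_by_prefix_sketch(["utt1.wav", "utt1.txt"])
# -> [{"fname": "utt1", "wav": "utt1.wav", "txt": "utt1.txt"}]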
@ -4,23 +4,22 @@
# This file is part of the WebDataset library. # This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style). # See the LICENSE file for licensing terms (BSD-style).
# #
# Modified from https://github.com/webdataset/webdataset # Modified from https://github.com/webdataset/webdataset
"""Miscellaneous utility functions.""" """Miscellaneous utility functions."""
import importlib import importlib
import itertools as itt import itertools as itt
import os import os
import re import re
import sys import sys
from typing import Any from typing import Any, Callable, Iterator, Optional, Union
from typing import Callable
from typing import Iterator
from typing import Union
from ..utils.log import Logger from ..utils.log import Logger
logger = Logger(__name__) logger = Logger(__name__)
def make_seed(*args): def make_seed(*args):
seed = 0 seed = 0
for arg in args: for arg in args:
@ -38,7 +37,7 @@ def identity(x: Any) -> Any:
return x return x
def safe_eval(s: str, expr: str="{}"): def safe_eval(s: str, expr: str = "{}"):
"""Evaluate the given expression more safely.""" """Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s: if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'") raise ValueError(f"safe_eval: illegal characters in: '{s}'")
@ -55,9 +54,9 @@ def lookup_sym(sym: str, modules: list):
return None return None
def repeatedly0(loader: Iterator, def repeatedly0(
nepochs: int=sys.maxsize, loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
nbatches: int=sys.maxsize): ):
"""Repeatedly returns batches from a DataLoader.""" """Repeatedly returns batches from a DataLoader."""
for epoch in range(nepochs): for epoch in range(nepochs):
for sample in itt.islice(loader, nbatches): for sample in itt.islice(loader, nbatches):
@ -71,10 +70,11 @@ def guess_batchsize(batch: Union[tuple, list]):
def repeatedly( def repeatedly(
source: Iterator, source: Iterator,
nepochs: int=None, nepochs: int = None,
nbatches: int=None, nbatches: int = None,
nsamples: int=None, nsamples: int = None,
batchsize: Callable[..., int]=guess_batchsize, ): batchsize: Callable[..., int] = guess_batchsize,
):
"""Repeatedly yield samples from an iterator.""" """Repeatedly yield samples from an iterator."""
epoch = 0 epoch = 0
batch = 0 batch = 0
@ -93,7 +93,6 @@ def repeatedly(
if nepochs is not None and epoch >= nepochs: if nepochs is not None and epoch >= nepochs:
return return
def paddle_worker_info(group=None): def paddle_worker_info(group=None):
"""Return node and worker info for PyTorch and some distributed environments.""" """Return node and worker info for PyTorch and some distributed environments."""
rank = 0 rank = 0
@ -117,7 +116,7 @@ def paddle_worker_info(group=None):
else: else:
try: try:
from paddle.io import get_worker_info from paddle.io import get_worker_info
worker_info = get_worker_info() worker_info = paddle.io.get_worker_info()
if worker_info is not None: if worker_info is not None:
worker = worker_info.id worker = worker_info.id
num_workers = worker_info.num_workers num_workers = worker_info.num_workers
@ -127,7 +126,6 @@ def paddle_worker_info(group=None):
return rank, world_size, worker, num_workers return rank, world_size, worker, num_workers
def paddle_worker_seed(group=None): def paddle_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node.""" """Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = paddle_worker_info(group=group) rank, world_size, worker, num_workers = paddle_worker_info(group=group)

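# Usage sketch (hedged): paddle_worker_info() returns (rank, world_size, worker,
# num_workers), so shard splitting and shuffling can differ per node and per
# dataloader worker. paddle_worker_seed() may combine these values differently; the
# helper below only illustrates deriving a per-worker seed with this module's make_seed.
def _worker_seed_sketch(group=None):
    rank, world_size, worker, num_workers = paddle_worker_info(group=group)
    return make_seed(rank, worker)  # deterministic for a given (node, worker) pair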
@ -5,24 +5,18 @@
# See the LICENSE file for licensing terms (BSD-style). # See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset # Modified from https://github.com/webdataset/webdataset
# #
"""Classes and functions for writing tar files and WebDataset files.""" """Classes and functions for writing tar files and WebDataset files."""
import io
import json import io, json, pickle, re, tarfile, time
import pickle from typing import Any, Callable, Optional, Union
import re
import tarfile
import time
from typing import Any
from typing import Callable
from typing import Optional
from typing import Union
import numpy as np import numpy as np
from . import gopen from . import gopen
def imageencoder(image: Any, format: str="PNG"): # skipcq: PYL-W0622 def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string. """Compress an image using PIL and return it as a string.
Can handle float or uint8 images. Can handle float or uint8 images.
@ -73,7 +67,6 @@ def bytestr(data: Any):
return data.encode("ascii") return data.encode("ascii")
return str(data).encode("ascii") return str(data).encode("ascii")
def paddle_dumps(data: Any): def paddle_dumps(data: Any):
"""Dump data into a bytestring using paddle.dumps. """Dump data into a bytestring using paddle.dumps.
@ -89,7 +82,6 @@ def paddle_dumps(data: Any):
paddle.save(data, stream) paddle.save(data, stream)
return stream.getvalue() return stream.getvalue()
def numpy_dumps(data: np.ndarray): def numpy_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npy format. """Dump data into a bytestring using numpy npy format.
@ -147,8 +139,9 @@ def add_handlers(d, keys, value):
def make_handlers(): def make_handlers():
"""Create a list of handlers for encoding data.""" """Create a list of handlers for encoding data."""
handlers = {} handlers = {}
add_handlers(handlers, "cls cls2 class count index inx id", add_handlers(
lambda x: str(x).encode("ascii")) handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
)
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
add_handlers(handlers, "pyd pickle", pickle.dumps) add_handlers(handlers, "pyd pickle", pickle.dumps)
@ -159,8 +152,7 @@ def make_handlers():
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
add_handlers(handlers, "mp msgpack msg", mp_dumps) add_handlers(handlers, "mp msgpack msg", mp_dumps)
add_handlers(handlers, "cbor", cbor_dumps) add_handlers(handlers, "cbor", cbor_dumps)
add_handlers(handlers, "jpg jpeg img image", add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
lambda data: imageencoder(data, "jpg"))
add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
@ -200,8 +192,7 @@ def encode_based_on_extension(sample: dict, handlers: dict):
:param handlers: handlers for encoding :param handlers: handlers for encoding
""" """
return { return {
k: encode_based_on_extension1(v, k, handlers) k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
for k, v in list(sample.items())
} }
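# Usage sketch (hedged; keys and values are illustrative): make_handlers() builds a
# table mapping extensions to byte encoders, and encode_based_on_extension applies
# them per key before a sample is written to a tar shard.
def _encode_sketch():
    handlers = make_handlers()
    sample = {"txt": "hello world", "cls": 3}
    return encode_based_on_extension(sample, handlers)
# expected result: {"txt": b"hello world", "cls": b"3"}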
@ -269,12 +260,13 @@ class TarWriter:
def __init__( def __init__(
self, self,
fileobj, fileobj,
user: str="bigdata", user: str = "bigdata",
group: str="bigdata", group: str = "bigdata",
mode: int=0o0444, mode: int = 0o0444,
compress: Optional[bool]=None, compress: Optional[bool] = None,
encoder: Union[None, bool, Callable]=True, encoder: Union[None, bool, Callable] = True,
keep_meta: bool=False, ): keep_meta: bool = False,
):
"""Create a tar writer. """Create a tar writer.
:param fileobj: stream to write data to :param fileobj: stream to write data to
@ -338,7 +330,8 @@ class TarWriter:
continue continue
if not isinstance(v, (bytes, bytearray, memoryview)): if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError( raise ValueError(
f"{k} doesn't map to a bytes after encoding ({type(v)})") f"{k} doesn't map to a bytes after encoding ({type(v)})"
)
key = obj["__key__"] key = obj["__key__"]
for k in sorted(obj.keys()): for k in sorted(obj.keys()):
if k == "__key__": if k == "__key__":
@ -356,8 +349,7 @@ class TarWriter:
ti.uname = self.user ti.uname = self.user
ti.gname = self.group ti.gname = self.group
if not isinstance(v, (bytes, bytearray, memoryview)): if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError( raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
f"converter didn't yield bytes: {k}, {type(v)}")
stream = io.BytesIO(v) stream = io.BytesIO(v)
self.tarstream.addfile(ti, stream) self.tarstream.addfile(ti, stream)
total += ti.size total += ti.size
@ -370,11 +362,12 @@ class ShardWriter:
def __init__( def __init__(
self, self,
pattern: str, pattern: str,
maxcount: int=100000, maxcount: int = 100000,
maxsize: float=3e9, maxsize: float = 3e9,
post: Optional[Callable]=None, post: Optional[Callable] = None,
start_shard: int=0, start_shard: int = 0,
**kw, ): **kw,
):
"""Create a ShardWriter. """Create a ShardWriter.
:param pattern: output file pattern :param pattern: output file pattern
@ -407,7 +400,8 @@ class ShardWriter:
self.fname, self.fname,
self.count, self.count,
"%.1f GB" % (self.size / 1e9), "%.1f GB" % (self.size / 1e9),
self.total, ) self.total,
)
self.shard += 1 self.shard += 1
stream = open(self.fname, "wb") stream = open(self.fname, "wb")
self.tarstream = TarWriter(stream, **self.kw) self.tarstream = TarWriter(stream, **self.kw)
@ -419,8 +413,11 @@ class ShardWriter:
:param obj: sample to be written :param obj: sample to be written
""" """
if (self.tarstream is None or self.count >= self.maxcount or if (
self.size >= self.maxsize): self.tarstream is None
or self.count >= self.maxcount
or self.size >= self.maxsize
):
self.next_stream() self.next_stream()
size = self.tarstream.write(obj) size = self.tarstream.write(obj)
self.count += 1 self.count += 1

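# Usage sketch for the writer classes above (hedged; assumes the webdataset-style
# write()/close() interface, and the output pattern is illustrative only):
# ShardWriter starts a new tar file whenever maxcount samples or maxsize bytes
# have been written.
def _shard_writer_sketch():
    writer = ShardWriter("data/shard-%06d.tar", maxcount=1000)
    writer.write({"__key__": "utt1", "txt": "hello world"})  # "__key__" is required
    writer.close()
# produces data/shard-000000.tar, data/shard-000001.tar, ... as samples accumulate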
@ -17,7 +17,6 @@ from typing import Union
import sentencepiece as spm import sentencepiece as spm
from ..utils.log import Logger
from .utility import BLANK from .utility import BLANK
from .utility import EOS from .utility import EOS
from .utility import load_dict from .utility import load_dict
@ -25,6 +24,7 @@ from .utility import MASKCTC
from .utility import SOS from .utility import SOS
from .utility import SPACE from .utility import SPACE
from .utility import UNK from .utility import UNK
from ..utils.log import Logger
logger = Logger(__name__) logger = Logger(__name__)

@ -12,16 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
import io
import os
import h5py
import librosa import librosa
import numpy import numpy
import numpy as np
import scipy import scipy
import soundfile import soundfile
import io
import os
import h5py
import numpy as np
class SoundHDF5File(): class SoundHDF5File():
"""Collecting sound files to a HDF5 file """Collecting sound files to a HDF5 file
@ -110,7 +109,6 @@ class SoundHDF5File():
def close(self): def close(self):
self.file.close() self.file.close()
class SpeedPerturbation(): class SpeedPerturbation():
"""SpeedPerturbation """SpeedPerturbation
@ -560,3 +558,4 @@ class RIRConvolve():
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1) [scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
else: else:
return scipy.convolve(x, rir, mode="same") return scipy.convolve(x, rir, mode="same")

@ -14,7 +14,6 @@
# Modified from espnet(https://github.com/espnet/espnet) # Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation""" """Spec Augment module for preprocessing i.e., data augmentation"""
import random import random
import numpy import numpy
from PIL import Image from PIL import Image

@ -381,6 +381,36 @@ class LogMelSpectrogramKaldi():
mat = np.squeeze(mat.numpy()) mat = np.squeeze(mat.numpy())
return mat return mat
class WavProcess():
def __init__(
self,
dither=0.1):
"""
Args:
dither (float): Dithering constant
Returns:
"""
self.dither = dither
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True, train mode.
Raises:
ValueError: input of shape (Ti, C) is not supported
Returns:
np.ndarray: (T, D)
"""
dither = self.dither if train else 0.0
if x.ndim != 1:
raise ValueError("Not support x: [Time, Channel]")
waveform = np.expand_dims(x, -1)
return waveform
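# Usage sketch: WavProcess only adds a trailing channel axis to a mono waveform so it
# matches the (T, D) layout expected downstream; note that in this version the dither
# value is computed but not applied to the signal.
def _wav_process_sketch():
    wav_process = WavProcess(dither=0.1)
    x = np.zeros(16000, dtype=np.float32)   # 1 second of 16 kHz mono audio
    return wav_process(x, train=True)       # shape (16000, 1)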
class LogMelSpectrogramKaldi_decay(): class LogMelSpectrogramKaldi_decay():
def __init__( def __init__(

@ -41,6 +41,7 @@ import_alias = dict(
utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN", utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram", fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram", spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
stft="paddlespeech.audio.transform.spectrogram:Stft", stft="paddlespeech.audio.transform.spectrogram:Stft",
istft="paddlespeech.audio.transform.spectrogram:IStft", istft="paddlespeech.audio.transform.spectrogram:IStft",
stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram", stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",

@ -99,9 +99,8 @@ class ASRExecutor(BaseExecutor):
'-y', '-y',
action="store_true", action="store_true",
default=False, default=False,
help='No additional parameters required. \ help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate'
Once set this parameter, it means accepting the request of the program by default, \ )
which includes transforming the audio sample rate')
self.parser.add_argument( self.parser.add_argument(
'--rtf', '--rtf',
action="store_true", action="store_true",
@ -341,7 +340,7 @@ class ASRExecutor(BaseExecutor):
audio = np.round(audio).astype("int16") audio = np.round(audio).astype("int16")
return audio return audio
def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000: if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error( logger.error(
@ -435,17 +434,8 @@ class ASRExecutor(BaseExecutor):
for id_, input_ in task_source.items(): for id_, input_ in task_source.items():
try: try:
res = self( res = self(input_, model, lang, sample_rate, config, ckpt_path,
audio_file=input_, decode_method, force_yes, rtf, device)
model=model,
lang=lang,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
decode_method=decode_method,
force_yes=force_yes,
rtf=rtf,
device=device)
task_results[id_] = res task_results[id_] = res
except Exception as e: except Exception as e:
has_exceptions = True has_exceptions = True

@ -70,14 +70,6 @@ class VectorExecutor(BaseExecutor):
type=str, type=str,
default=None, default=None,
help="Checkpoint file of model.") help="Checkpoint file of model.")
self.parser.add_argument(
'--yes',
'-y',
action="store_true",
default=False,
help='No additional parameters required. \
Once set this parameter, it means accepting the request of the program by default, \
which includes transforming the audio sample rate')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
type=str, type=str,
@ -117,7 +109,6 @@ class VectorExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate sample_rate = parser_args.sample_rate
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path ckpt_path = parser_args.ckpt_path
force_yes = parser_args.yes
device = parser_args.device device = parser_args.device
# stage 1: configurate the verbose flag # stage 1: configurate the verbose flag
@ -137,14 +128,8 @@ class VectorExecutor(BaseExecutor):
# extract the speaker audio embedding # extract the speaker audio embedding
if parser_args.task == "spk": if parser_args.task == "spk":
logger.debug("do vector spk task") logger.debug("do vector spk task")
res = self( res = self(input_, model, sample_rate, config, ckpt_path,
audio_file=input_, device)
model=model,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
task_result[id_] = res task_result[id_] = res
elif parser_args.task == "score": elif parser_args.task == "score":
logger.debug("do vector score task") logger.debug("do vector score task")
@ -160,22 +145,10 @@ class VectorExecutor(BaseExecutor):
logger.debug( logger.debug(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}" f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
) )
enroll_embedding = self( enroll_embedding = self(enroll_audio, model, sample_rate,
audio_file=enroll_audio, config, ckpt_path, device)
model=model, test_embedding = self(test_audio, model, sample_rate,
sample_rate=sample_rate, config, ckpt_path, device)
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
test_embedding = self(
audio_file=test_audio,
model=model,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
# get the score # get the score
res = self.get_embeddings_score(enroll_embedding, res = self.get_embeddings_score(enroll_embedding,
@ -249,7 +222,6 @@ class VectorExecutor(BaseExecutor):
sample_rate: int=16000, sample_rate: int=16000,
config: os.PathLike=None, config: os.PathLike=None,
ckpt_path: os.PathLike=None, ckpt_path: os.PathLike=None,
force_yes: bool=False,
device=paddle.get_device()): device=paddle.get_device()):
"""Extract the audio embedding """Extract the audio embedding
@ -268,7 +240,7 @@ class VectorExecutor(BaseExecutor):
""" """
# stage 0: check the audio format # stage 0: check the audio format
audio_file = os.path.abspath(audio_file) audio_file = os.path.abspath(audio_file)
if not self._check(audio_file, sample_rate, force_yes): if not self._check(audio_file, sample_rate):
sys.exit(-1) sys.exit(-1)
# stage 1: set the paddle runtime host device # stage 1: set the paddle runtime host device
@ -446,7 +418,7 @@ class VectorExecutor(BaseExecutor):
logger.debug("audio extract the feat success") logger.debug("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): def _check(self, audio_file: str, sample_rate: int):
"""Check if the model sample match the audio sample rate """Check if the model sample match the audio sample rate
Args: Args:
@ -490,34 +462,13 @@ class VectorExecutor(BaseExecutor):
logger.debug(f"The sample rate is {audio_sample_rate}") logger.debug(f"The sample rate is {audio_sample_rate}")
if audio_sample_rate != self.sample_rate: if audio_sample_rate != self.sample_rate:
logger.debug("The sample rate of the input file is not {}.\n \ logger.error("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \ The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \ If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \ Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate)) ".format(self.sample_rate, self.sample_rate))
if force_yes is False: sys.exit(-1)
while (True):
logger.debug(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip(
) == "Yes":
logger.debug(
"change the sampele rate, channel to 16k and 1 channel"
)
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip(
) == "No":
logger.debug("Exit the program")
return False
else:
logger.warning("Not regular input, please input again")
self.change_format = True
else: else:
logger.debug("The audio file format is right") logger.debug("The audio file format is right")
self.change_format = False
return True return True

@ -1363,11 +1363,5 @@ g2pw_onnx_models = {
'md5': 'md5':
'7e049a55547da840502cf99e8a64f20e', '7e049a55547da840502cf99e8a64f20e',
}, },
'1.1': {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip',
'md5':
'f8b60501770bff92ed6ce90860a610e6',
},
}, },
} }

@ -114,7 +114,6 @@ if not hasattr(paddle.Tensor, 'new_full'):
paddle.Tensor.new_full = new_full paddle.Tensor.new_full = new_full
paddle.static.Variable.new_full = new_full paddle.static.Variable.new_full = new_full
def contiguous(xs: paddle.Tensor) -> paddle.Tensor: def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs return xs

@ -20,8 +20,8 @@ import paddle
import soundfile import soundfile
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils import mp_tools
@ -38,24 +38,24 @@ class DeepSpeech2Tester_hub():
self.args = args self.args = args
self.config = config self.config = config
self.audio_file = args.audio_file self.audio_file = args.audio_file
self.collate_fn_test = SpeechCollator.from_config(config)
self.preprocess_conf = config.preprocess_config self._text_featurizer = TextFeaturizer(
self.preprocess_args = {"train": False} unit_type=config.unit_type, vocab=None)
self.preprocessing = Transformation(self.preprocess_conf)
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath,
spm_model_prefix=config.spm_model_prefix)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
decode_batch_size = cfg.decode_batch_size result_transcripts = self.model.decode(
self.model.decoder.init_decoder( audio,
decode_batch_size, vocab_list, cfg.decoding_method, audio_len,
cfg.lang_model_path, cfg.alpha, cfg.beta, cfg.beam_size, vocab_list,
cfg.cutoff_prob, cfg.cutoff_top_n, cfg.num_proc_bsearch) decoding_method=cfg.decoding_method,
result_transcripts = self.model.decode(audio, audio_len) lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
return result_transcripts return result_transcripts
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@ -64,23 +64,16 @@ class DeepSpeech2Tester_hub():
self.model.eval() self.model.eval()
cfg = self.config cfg = self.config
audio_file = self.audio_file audio_file = self.audio_file
collate_fn_test = self.collate_fn_test
audio, sample_rate = soundfile.read( audio, _ = collate_fn_test.process_utterance(
self.audio_file, dtype="int16", always_2d=True) audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
audio = audio[:, 0] audio = paddle.to_tensor(audio, dtype='float32')
logger.info(f"audio shape: {audio.shape}") audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
# fbank vocab_list = collate_fn_test.vocab_list
feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}")
audio_len = paddle.to_tensor(feat.shape[0])
audio = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
result_transcripts = self.compute_result_transcripts( result_transcripts = self.compute_result_transcripts(
audio, audio_len, self.text_feature.vocab_list, cfg.decode) audio, audio_len, vocab_list, cfg.decode)
logger.info("result_transcripts: " + result_transcripts[0]) logger.info("result_transcripts: " + result_transcripts[0])
def run_test(self): def run_test(self):
@ -116,9 +109,11 @@ class DeepSpeech2Tester_hub():
def setup_model(self): def setup_model(self):
config = self.config.clone() config = self.config.clone()
with UpdateConfig(config): with UpdateConfig(config):
config.input_dim = config.feat_dim config.input_dim = self.collate_fn_test.feature_size
config.output_dim = self.text_feature.vocab_size config.output_dim = self.collate_fn_test.vocab_size
model = DeepSpeech2Model.from_config(config) model = DeepSpeech2Model.from_config(config)
self.model = model self.model = model
def setup_checkpointer(self): def setup_checkpointer(self):

@ -25,6 +25,8 @@ import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.optimizer import OptimizerFactory
@ -107,8 +109,7 @@ class U2Trainer(Trainer):
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
@ -135,8 +136,7 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata: if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
logger.info(msg) logger.info(msg)
@ -157,8 +157,7 @@ class U2Trainer(Trainer):
self.before_train() self.before_train()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
@ -226,18 +225,14 @@ class U2Trainer(Trainer):
config = self.config.clone() config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False) self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
self.train_loader = DataLoaderFactory.get_dataloader( self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
'train', config, self.args) self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
decode_batch_size = config.get('decode', dict()).get( decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1) 'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.args) self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!") logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
@ -250,8 +245,7 @@ class U2Trainer(Trainer):
model_conf.output_dim = self.train_loader.vocab_size model_conf.output_dim = self.train_loader.vocab_size
else: else:
model_conf.input_dim = self.test_loader.feat_dim model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size model_conf.output_dim = 5538
model = U2Model.from_config(model_conf) model = U2Model.from_config(model_conf)
if self.parallel: if self.parallel:
@ -316,6 +310,11 @@ class U2Tester(U2Trainer):
unit_type=self.config.unit_type, unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath, vocab=self.config.vocab_filepath,
spm_model_prefix=self.config.spm_model_prefix) spm_model_prefix=self.config.spm_model_prefix)
self.text_feature_test = TextFeaturizer(
unit_type=self.config.unit_type,
vocab='/home/zhangtianhao/workspace/PaddleSpeech/examples/aishell/asr1/data/lang_char/vocab.txt',
spm_model_prefix=self.config.spm_model_prefix)
self.vocab_list = self.text_feature.vocab_list self.vocab_list = self.text_feature.vocab_list
def id2token(self, texts, texts_len, text_feature): def id2token(self, texts, texts_len, text_feature):
@ -340,7 +339,7 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if decode_config.error_rate_type == 'cer' else error_rate.wer
start_time = time.time() start_time = time.time()
target_transcripts = self.id2token(texts, texts_len, self.text_feature) target_transcripts = self.id2token(texts, texts_len, self.text_feature_test)
result_transcripts, result_tokenids = self.model.decode( result_transcripts, result_tokenids = self.model.decode(
audio, audio,
audio_len, audio_len,

@ -105,8 +105,7 @@ class U2Trainer(Trainer):
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
@ -134,8 +133,7 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata: if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
logger.info(msg) logger.info(msg)
@ -155,8 +153,7 @@ class U2Trainer(Trainer):
self.before_train() self.before_train()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
@ -168,8 +165,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata: if not self.use_streamdata:
msg += "batch : {}/{}, ".format( msg += "batch : {}/{}, ".format(batch_index + 1,
batch_index + 1, len(self.train_loader)) len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time) msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg) self.train_batch(batch_index, batch, msg)
@ -207,24 +204,21 @@ class U2Trainer(Trainer):
self.use_streamdata = config.get("use_stream_data", False) self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
config = self.config.clone() config = self.config.clone()
self.train_loader = DataLoaderFactory.get_dataloader( self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
'train', config, self.args)
config = self.config.clone() config = self.config.clone()
config['preprocess_config'] = None config['preprocess_config'] = None
self.valid_loader = DataLoaderFactory.get_dataloader( self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
config = self.config.clone() config = self.config.clone()
config['preprocess_config'] = None config['preprocess_config'] = None
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.args)
config = self.config.clone() config = self.config.clone()
config['preprocess_config'] = None config['preprocess_config'] = None
self.align_loader = DataLoaderFactory.get_dataloader( self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
'align', config, self.args)
logger.info("Setup test/align Dataloader!") logger.info("Setup test/align Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config

@ -121,8 +121,7 @@ class U2STTrainer(Trainer):
def valid(self): def valid(self):
self.model.eval() self.model.eval()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list) valid_losses = defaultdict(list)
num_seen_utts = 1 num_seen_utts = 1
total_loss = 0.0 total_loss = 0.0
@ -156,8 +155,7 @@ class U2STTrainer(Trainer):
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration) msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata: if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v) msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items()) for k, v in valid_dump.items())
logger.info(msg) logger.info(msg)
@ -177,8 +175,7 @@ class U2STTrainer(Trainer):
self.before_train() self.before_train()
if not self.use_streamdata: if not self.use_streamdata:
logger.info( logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch: while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"): with Timer("Epoch-Train Time Cost: {}"):
self.model.train() self.model.train()
@ -251,16 +248,14 @@ class U2STTrainer(Trainer):
config['load_transcript'] = load_transcript config['load_transcript'] = load_transcript
self.use_streamdata = config.get("use_stream_data", False) self.use_streamdata = config.get("use_stream_data", False)
if self.train: if self.train:
self.train_loader = DataLoaderFactory.get_dataloader( self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
'train', config, self.args) self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!") logger.info("Setup train/valid Dataloader!")
else: else:
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.args)
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
def setup_model(self): def setup_model(self):
config = self.config config = self.config
model_conf = config model_conf = config

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,66 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
exp = Tester(config, args)
with exp.eval():
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats('test.profile')
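# Example invocation (hedged; the file name, config path and output path are
# illustrative only, and flags other than --result_file and --dict-path come from
# default_argument_parser):
#   python test.py --config conf/wav2vec2ASR.yaml --result_file exp/wav2vec2ASR/test.rst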

@ -0,0 +1,55 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import cProfile
import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))

@ -0,0 +1,465 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains wav2vec2 model."""
import json
import os
import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from paddlespeech.s2t.utils import mp_tools
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.models.wav2vec2.speechbrain.processing.speech_augmentation import TimeDomainSpecAugment
import pdb
logger = Log(__name__).getlog()
class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
# forward
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
if train_conf.augment:
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
# print(self.model.wav2vec2.feature_projection.projection.weight)
# print(self.model.wav2vec2.feature_extractor.conv_layers[0].conv.weight)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
losses_np = {'loss': float(loss) * train_conf.accum_grad}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
# When using cpu w/o DDP, model does not have `no_sync`
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step old
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
# optimizer step new
# if (batch_index + 1) % train_conf.accum_grad == 0:
# self.optimizer.step()
# self.optimizer.clear_grad()
# self.iteration += 1
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)
report("batch_size", self.config.batch_size)
report("accum", train_conf.accum_grad)
report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
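# Concrete example of the accumulation above: with accum_grad = 4 and batch_size = 8,
# gradients from 4 consecutive batches are summed before a single optimizer/scheduler
# step, giving an effective batch size of 32; the loss is divided by accum_grad before
# backward and multiplied back only for logging.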
@paddle.no_grad()
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
if (i + 1) % self.config.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def do_train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# script_model = paddle.jit.to_static(self.model)
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
if not self.use_streamdata:
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips,samples/s'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k.split(',')[0]}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += f" {k.split(',')[1]}" if len(
k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1] # remove the last ","
if (batch_index + 1) % self.config.log_interval == 0:
logger.info(msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
def setup_dataloader(self):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config
with UpdateConfig(model_conf):
if self.train:
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = Wav2vec2ASR.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model)
# logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
if model_conf.augment:
self.speech_augmentation = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])
if not self.train:
return
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.dnn_neurons,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
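# Note (hedged): for scheduler_type == 'noam' the factory above is expected to follow
# the usual Noam schedule,
#   lr(step) = learning_rate * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5),
# i.e. a linear warmup over warmup_steps followed by 1/sqrt(step) decay; gamma/lr_decay
# is presumably consumed by the decaying schedulers rather than by noam.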
def optimizer_args(
config,
parameters,
lr_scheduler=None, ):
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"epsilon": optim_conf.epsilon,
"rho": optim_conf.rho,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
# optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup optimizer/lr_scheduler!")
class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
def __init__(self, config, args):
super().__init__(config, args)
print(config)
self.text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self.text_featurizer.vocab_list
def id2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(
self.text_featurizer.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
decode_cfg = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len)
result_transcripts, result_tokenids = self.model.decode(
audio,
audio_len,
text_feature=self.text_featurizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(
utts, target_transcripts, result_transcripts, result_tokenids):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % (
decode_cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type,
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
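# Note: error_rate here is a corpus-level rate, errors_sum / len_refs. For example,
# with error_rate_type == 'cer', 5 character edits against 100 reference characters
# gives 5 / 100 = 0.05, aggregated over utterances rather than averaged per utterance.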
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size
# self.model.decoder.init_decoder(
# decode_batch_size, vocab_list, decode_cfg.decoding_method,
# decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
# decode_cfg.beam_size, decode_cfg.cutoff_prob,
# decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
rtf = num_time / (num_frames)
logger.info(
"RTF: %f, Error rate [%s] (%d/?) = %f" %
(rtf, error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
error_rate_type:
errors_sum / len_refs,
"dataset_hour": (num_frames) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"err_sum":
errors_sum,
"ref_len":
len_refs,
"decode_method":
self.config.decode.decoding_method,
})
f.write(data + '\n')
@paddle.no_grad()
def export(self):
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader, self.config, self.args.checkpoint_path)
infer_model.eval()
static_model = infer_model.export()
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)

@ -22,16 +22,17 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
+from yacs.config import CfgNode
+
+import paddlespeech.audio.streamdata as streamdata
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.batchfy import make_batchset
from paddlespeech.s2t.io.converter import CustomConverter
from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
-import paddlespeech.audio.streamdata as streamdata
-from paddlespeech.audio.text.text_featurizer import TextFeaturizer
-from yacs.config import CfgNode

__all__ = ["BatchDataLoader", "StreamDataLoader"]

logger = Log(__name__).getlog()

@ -60,7 +61,6 @@ def batch_collate(x):
    """
    return x[0]

def read_preprocess_cfg(preprocess_conf_file):
    augment_conf = dict()
    preprocess_cfg = CfgNode(new_allowed=True)
@ -84,7 +84,6 @@ read_preprocess_cfg(preprocess_conf_file):
            augment_conf['t_replace_with_zero'] = process['replace_with_zero']
    return augment_conf

class StreamDataLoader():
    def __init__(self,
                 manifest_file: str,

@ -132,14 +131,10 @@ class StreamDataLoader():
            world_size = paddle.distributed.get_world_size()
        except Exception as e:
            logger.warninig(e)
-            logger.warninig(
-                "can not get world_size using paddle.distributed.get_world_size(), use world_size=1"
-            )
-        assert len(shardlist) >= world_size, \
-            "the length of shard list should >= number of gpus/xpus/..."
+            logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
+        assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...")

-        update_n_iter_processes = int(
-            max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
+        update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
        logger.info(f"update_n_iter_processes {update_n_iter_processes}")
        if update_n_iter_processes != self.n_iter_processes:
            self.n_iter_processes = update_n_iter_processes
@ -147,50 +142,44 @@ class StreamDataLoader():
        if self.dist_sampler:
            base_dataset = streamdata.DataPipeline(
-                streamdata.SimpleShardList(shardlist), streamdata.split_by_node
-                if train_mode else streamdata.placeholder(),
+                streamdata.SimpleShardList(shardlist),
+                streamdata.split_by_node if train_mode else streamdata.placeholder(),
                streamdata.split_by_worker,
-                streamdata.tarfile_to_samples(streamdata.reraise_exception))
+                streamdata.tarfile_to_samples(streamdata.reraise_exception)
+            )
        else:
            base_dataset = streamdata.DataPipeline(
                streamdata.SimpleShardList(shardlist),
                streamdata.split_by_worker,
-                streamdata.tarfile_to_samples(streamdata.reraise_exception))
+                streamdata.tarfile_to_samples(streamdata.reraise_exception)
+            )

        self.dataset = base_dataset.append_list(
            streamdata.audio_tokenize(symbol_table),
-            streamdata.audio_data_filter(
-                frame_shift=frame_shift,
-                max_length=maxlen_in,
-                min_length=minlen_in,
-                token_max_length=maxlen_out,
-                token_min_length=minlen_out),
+            streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
            streamdata.audio_resample(resample_rate=resample_rate),
-            streamdata.audio_compute_fbank(
-                num_mel_bins=num_mel_bins,
-                frame_length=frame_length,
-                frame_shift=frame_shift,
-                dither=dither),
-            streamdata.audio_spec_aug(**augment_conf)
-            if train_mode else streamdata.placeholder(
-            ),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
+            streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
+            streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(),  # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
            streamdata.shuffle(shuffle_size),
            streamdata.sort(sort_size=sort_size),
            streamdata.batched(batch_size),
            streamdata.audio_padding(),
-            streamdata.audio_cmvn(cmvn_file))
+            streamdata.audio_cmvn(cmvn_file)
+        )
        if paddle.__version__ >= '2.3.2':
            self.loader = streamdata.WebLoader(
                self.dataset,
                num_workers=self.n_iter_processes,
-                prefetch_factor=self.prefetch_factor,
-                batch_size=None)
+                prefetch_factor = self.prefetch_factor,
+                batch_size=None
+            )
        else:
            self.loader = streamdata.WebLoader(
                self.dataset,
                num_workers=self.n_iter_processes,
-                batch_size=None)
+                batch_size=None
+            )

    def __iter__(self):
        return self.loader.__iter__()

@ -199,9 +188,7 @@ class StreamDataLoader():
        return self.__iter__()

    def __len__(self):
-        logger.info(
-            "Stream dataloader does not support calculate the length of the dataset"
-        )
+        logger.info("Stream dataloader does not support calculate the length of the dataset")
        return -1
@ -371,9 +358,7 @@ class DataLoaderFactory():
            config['maxlen_out'] = float('inf')
            config['dist_sampler'] = False
        else:
-            raise KeyError(
-                "not valid mode type!!, please input one of 'train, valid, test, align'"
-            )
+            raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
        return StreamDataLoader(
            manifest_file=config.manifest,
            train_mode=config.train_mode,

@ -395,7 +380,8 @@ class DataLoaderFactory():
            prefetch_factor=config.prefetch_factor,
            dist_sampler=config.dist_sampler,
            cmvn_file=config.cmvn_file,
-            vocab_filepath=config.vocab_filepath, )
+            vocab_filepath=config.vocab_filepath,
+        )
    else:
        if mode == 'train':
            config['manifest'] = config.train_manifest

@ -441,9 +427,7 @@ class DataLoaderFactory():
            config['dist_sampler'] = False
            config['shortest_first'] = False
        else:
-            raise KeyError(
-                "not valid mode type!!, please input one of 'train, valid, test, align'"
-            )
+            raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
        return BatchDataLoader(
            json_file=config.manifest,

@ -466,3 +450,4 @@ class DataLoaderFactory():
            num_encs=config.num_encs,
            dist_sampler=config.dist_sampler,
            shortest_first=config.shortest_first)

@ -120,6 +120,7 @@ class LoadInputsAndTargets():
                x = self._get_from_loader(
                    filepath=inp["feat"],
                    filetype=inp.get("filetype", "mat"))
                x_feats_dict.setdefault(inp["name"], []).append(x)

        if self.load_output:

@ -236,6 +237,7 @@ class LoadInputsAndTargets():
        :return:
        :rtype: np.ndarray
        """

        if filetype == "hdf5":
            # e.g.
            # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",

@ -271,7 +271,7 @@ class DeepSpeech2Model(nn.Layer):
            enc_n_units=self.encoder.output_size,
            blank_id=blank_id,
            dropout_rate=0.0,
-            reduction=True,  # sum
+            reduction_type="sum",  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=ctc_grad_norm_type)

@ -0,0 +1,20 @@
import paddle
import paddle.nn as nn


class Model(nn.Layer):
    # A single linear layer, just enough to produce gradients for the demo.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(x)


# Build the model, run a forward pass and backprop a scalar loss.
model = Model()
x = paddle.uniform([100, 1024], dtype='float32')
out = model(x)
loss = paddle.mean(out)
loss.backward()
# Clip the global gradient norm to 1.0, then apply the Adadelta update.
clip = nn.ClipGradByGlobalNorm(clip_norm=1.0)
optim = paddle.optimizer.Adadelta(
    learning_rate=0.1, parameters=model.parameters(), grad_clip=clip)
optim.step()

@ -605,8 +605,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
-            att_cache: paddle.Tensor,  # paddle.zeros([0, 0, 0, 0])
-            cnn_cache: paddle.Tensor,  # paddle.zeros([0, 0, 0, 0])
+            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
        output from time 0 to current chunk.

@ -864,7 +864,7 @@ class U2Model(U2DecodeModel):
            enc_n_units=encoder.output_size(),
            blank_id=0,
            dropout_rate=dropout_rate,
-            reduction=True,  # sum
+            reduction_type="sum",  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=grad_norm_type)

@ -18,6 +18,7 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni
"""
import time
from typing import Dict
+from typing import List
from typing import Optional
from typing import Tuple

@ -25,8 +26,6 @@ import paddle
from paddle import jit
from paddle import nn

-from paddlespeech.audio.utils.tensor_utils import add_sos_eos
-from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN

@ -39,6 +38,8 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import UpdateConfig

__all__ = ["U2STModel", "U2STInferModel"]

@ -400,8 +401,8 @@ class U2STBaseModel(nn.Layer):
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
-            att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
-            cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+            att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+            cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
        output from time 0 to current chunk.

@ -434,8 +435,8 @@ class U2STBaseModel(nn.Layer):
            paddle.Tensor: new conformer cnn cache required for next chunk, with
            same shape as the original cnn_cache.
        """
-        return self.encoder.forward_chunk(xs, offset, required_cache_size,
-                                          att_cache, cnn_cache)
+        return self.encoder.forward_chunk(
+            xs, offset, required_cache_size, att_cache, cnn_cache)

    # @jit.to_static
    def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:

@ -611,7 +612,7 @@ class U2STModel(U2STBaseModel):
            enc_n_units=encoder.output_size(),
            blank_id=0,
            dropout_rate=dropout_rate,
-            reduction=True,  # sum
+            reduction_type='sum',  # sum
            batch_average=True,  # sum / batch_size
            grad_norm_type=grad_norm_type)

@ -0,0 +1,175 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from packaging import version
import paddle
from paddle import Tensor, nn
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class NewGELUActivation(nn.Layer):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
class GELUActivation(nn.Layer):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Layer):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Layer):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * paddle.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Layer):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return paddle.clip(gelu(x), self.min, self.max)
class SiLUActivation(nn.Layer):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def __init__(self):
super().__init__()
self.act = nn.functional.silu
def _silu_python(self, input: Tensor) -> Tensor:
return input * paddle.sigmoid(input)
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class MishActivation(nn.Layer):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def __init__(self):
super().__init__()
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * paddle.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class LinearActivation(nn.Layer):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
ACT2FN = {
"gelu": GELUActivation(),
"gelu_10": ClippedGELUActivation(-10, 10),
"gelu_fast": FastGELUActivation(),
"gelu_new": NewGELUActivation(),
"gelu_python": GELUActivation(use_gelu_python=True),
"linear": LinearActivation(),
"mish": MishActivation(),
"quick_gelu": QuickGELUActivation(),
"relu": nn.ReLU(),
"sigmoid": nn.Sigmoid(),
"silu": SiLUActivation(),
"swish": SiLUActivation(),
"tanh": nn.Tanh(),
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
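# Editorial usage sketch (not part of the original file): every entry in ACT2FN above is
# an nn.Layer instance applied elementwise, so the looked-up object can be called directly.
#
#     >>> act = get_activation("gelu_new")        # same object as ACT2FN["gelu_new"]
#     >>> x = paddle.randn([2, 8])
#     >>> act(x).shape                            # elementwise, shape is preserved
#     [2, 8]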

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -0,0 +1,359 @@
import os
import paddle
import speechbrain as sb
from speechbrain.processing.speech_augmentation import (
SpeedPerturb,
DropFreq,
DropChunk,
)
class TimeDomainSpecAugment(paddle.nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
2. Drop frequency bands (with band-drop filters)
3. Speed peturbation (via resampling to slightly different rate)
Arguments
---------
perturb_prob : float from 0 to 1
The probability that a batch will have speed perturbation applied.
drop_freq_prob : float from 0 to 1
The probability that a batch will have frequencies dropped.
drop_chunk_prob : float from 0 to 1
The probability that a batch will have chunks dropped.
speeds : list of ints
A set of different speeds to use to perturb each batch.
See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
sample_rate : int
Sampling rate of the input waveforms.
drop_freq_count_low : int
Lowest number of frequencies that could be dropped.
drop_freq_count_high : int
Highest number of frequencies that could be dropped.
drop_chunk_count_low : int
Lowest number of chunks that could be dropped.
drop_chunk_count_high : int
Highest number of chunks that could be dropped.
drop_chunk_length_low : int
Lowest length of chunks that could be dropped.
drop_chunk_length_high : int
Highest length of chunks that could be dropped.
drop_chunk_noise_factor : float
The noise factor used to scale the white noise inserted, relative to
the average amplitude of the utterance. Default 0 (no noise inserted).
Example
-------
>>> inputs = torch.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, torch.ones(10))
>>> feats.shape
torch.Size([10, 12800])
"""
def __init__(
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0,
):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
)
self.drop_freq = DropFreq(
drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high,
)
self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor,
)
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
Arguments
---------
waveforms : torch.Tensor
The waveforms to distort
"""
# Augmentation
with paddle.no_grad():
waveforms = self.speed_perturb(waveforms)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
return waveforms
class DropFreq(torch.nn.Module):
"""This class drops a random frequency from the signal.
The purpose of this class is to teach models to learn to rely on all parts
of the signal, not just a few frequency bands.
Arguments
---------
drop_freq_low : float
The low end of frequencies that can be dropped,
as a fraction of the sampling rate / 2.
drop_freq_high : float
The high end of frequencies that can be
dropped, as a fraction of the sampling rate / 2.
drop_count_low : int
The low end of number of frequencies that could be dropped.
drop_count_high : int
The high end of number of frequencies that could be dropped.
drop_width : float
The width of the frequency band to drop, as
a fraction of the sampling_rate / 2.
drop_prob : float
The probability that the batch of signals will have a frequency
dropped. By default, every batch has frequencies dropped.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropFreq()
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> dropped_signal = dropper(signal.unsqueeze(0))
"""
def __init__(
self,
drop_freq_low=1e-14,
drop_freq_high=1,
drop_count_low=1,
drop_count_high=2,
drop_width=0.05,
drop_prob=1,
):
super().__init__()
self.drop_freq_low = drop_freq_low
self.drop_freq_high = drop_freq_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_width = drop_width
self.drop_prob = drop_prob
def forward(self, waveforms):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
# Don't drop (return early) 1-`drop_prob` portion of the batches
dropped_waveform = waveforms.clone()
if torch.rand(1) > self.drop_prob:
return dropped_waveform
# Add channels dimension
if len(waveforms.shape) == 2:
dropped_waveform = dropped_waveform.unsqueeze(-1)
# Pick number of frequencies to drop
drop_count = torch.randint(
low=self.drop_count_low, high=self.drop_count_high + 1, size=(1,),
)
# Pick a frequency to drop
drop_range = self.drop_freq_high - self.drop_freq_low
drop_frequency = (
torch.rand(drop_count) * drop_range + self.drop_freq_low
)
# Filter parameters
filter_length = 101
pad = filter_length // 2
# Start with delta function
drop_filter = torch.zeros(1, filter_length, 1, device=waveforms.device)
drop_filter[0, pad, 0] = 1
# Subtract each frequency
for frequency in drop_frequency:
notch_kernel = notch_filter(
frequency, filter_length, self.drop_width,
).to(waveforms.device)
drop_filter = convolve1d(drop_filter, notch_kernel, pad)
# Apply filter
dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)
# Remove channels dimension if added
return dropped_waveform.squeeze(-1)
class DropChunk(torch.nn.Module):
"""This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps a models learn to rely
on all parts of the signal, since it can't expect a given part to be
present.
Arguments
---------
drop_length_low : int
The low end of lengths for which to set the
signal to zero, in samples.
drop_length_high : int
The high end of lengths for which to set the
signal to zero, in samples.
drop_count_low : int
The low end of number of times that the signal
can be dropped to zero.
drop_count_high : int
The high end of number of times that the signal
can be dropped to zero.
drop_start : int
The first index for which dropping will be allowed.
drop_end : int
The last index for which dropping will be allowed.
drop_prob : float
The probability that the batch of signals will
have a portion dropped. By default, every batch
has portions dropped.
noise_factor : float
The factor relative to average amplitude of an utterance
to use for scaling the white noise inserted. 1 keeps
the average amplitude the same, while 0 inserts all 0's.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropChunk(drop_start=100, drop_end=200, noise_factor=0.)
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0) # [batch, time, channels]
>>> length = torch.ones(1)
>>> dropped_signal = dropper(signal, length)
>>> float(dropped_signal[:, 150])
0.0
"""
def __init__(
self,
drop_length_low=100,
drop_length_high=1000,
drop_count_low=1,
drop_count_high=10,
drop_start=0,
drop_end=None,
drop_prob=1,
noise_factor=0.0,
):
super().__init__()
self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_start = drop_start
self.drop_end = drop_end
self.drop_prob = drop_prob
self.noise_factor = noise_factor
# Validate low < high
if drop_length_low > drop_length_high:
raise ValueError("Low limit must not be more than high limit")
if drop_count_low > drop_count_high:
raise ValueError("Low limit must not be more than high limit")
# Make sure the length doesn't exceed end - start
if drop_end is not None and drop_end >= 0:
if drop_start > drop_end:
raise ValueError("Low limit must not be more than high limit")
drop_range = drop_end - drop_start
self.drop_length_low = min(drop_length_low, drop_range)
self.drop_length_high = min(drop_length_high, drop_range)
def forward(self, waveforms, lengths):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or
`[batch, time, channels]`
"""
# Reading input list
lengths = (lengths * waveforms.size(1)).long()
batch_size = waveforms.size(0)
dropped_waveform = waveforms.clone()
# Don't drop (return early) 1-`drop_prob` portion of the batches
if torch.rand(1) > self.drop_prob:
return dropped_waveform
# Store original amplitude for computing white noise amplitude
clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))
# Pick a number of times to drop
drop_times = torch.randint(
low=self.drop_count_low,
high=self.drop_count_high + 1,
size=(batch_size,),
)
# Iterate batch to set mask
for i in range(batch_size):
if drop_times[i] == 0:
continue
# Pick lengths
length = torch.randint(
low=self.drop_length_low,
high=self.drop_length_high + 1,
size=(drop_times[i],),
)
# Compute range of starting locations
start_min = self.drop_start
if start_min < 0:
start_min += lengths[i]
start_max = self.drop_end
if start_max is None:
start_max = lengths[i]
if start_max < 0:
start_max += lengths[i]
start_max = max(0, start_max - length.max())
# Pick starting locations
start = torch.randint(
low=start_min, high=start_max + 1, size=(drop_times[i],),
)
end = start + length
# Update waveform
if not self.noise_factor:
for j in range(drop_times[i]):
dropped_waveform[i, start[j] : end[j]] = 0.0
else:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
noise_max = 2 * clean_amplitude[i] * self.noise_factor
for j in range(drop_times[i]):
# zero-center the noise distribution
noise_vec = torch.rand(length[j], device=waveforms.device)
noise_vec = 2 * noise_max * noise_vec - noise_max
dropped_waveform[i, start[j] : end[j]] = noise_vec
return dropped_waveform

@ -0,0 +1,45 @@
"""Vanilla Neural Network for simple tests.
Authors
* Elena Rastorgueva 2020
"""
import paddle
from paddlespeech.s2t.models.wav2vec2.speechbrain.nnet import containers
import paddlespeech.s2t.models.wav2vec2.speechbrain as sb
class VanillaNN(containers.Sequential):
"""A simple vanilla Deep Neural Network.
Arguments
---------
activation : paddle class
A class used for constructing the activation layers.
dnn_blocks : int
The number of linear neural blocks to include.
dnn_neurons : int
The number of neurons in the linear layers.
Example
-------
>>> inputs = paddle.rand([10, 120, 60])
>>> model = VanillaNN(input_shape=inputs.shape)
>>> outputs = model(inputs)
>>> outputs.shape
paddle.shape([10, 120, 512])
"""
def __init__(
self,
input_shape,
activation=paddle.nn.LeakyReLU,
dnn_blocks=2,
dnn_neurons=512,
):
super().__init__(input_shape=input_shape)
for block_index in range(dnn_blocks):
self.append(
sb.nnet.linear.Linear,
n_neurons=dnn_neurons,
bias=True,
layer_name="linear",
)
self.append(activation(), layer_name="act")

@ -0,0 +1,2 @@
from . import linear
from . import containers

@ -0,0 +1,132 @@
import paddle
import inspect
import logging
import operator
import functools
class Sequential(paddle.nn.LayerDict):
"""A sequence of modules with potentially inferring shape on construction.
If layers are passed with names, these can be referenced with dot notation.
Arguments
---------
input_shape : iterable
A list or tuple of ints or None, representing the expected shape of an
input tensor. None represents a variable-length dimension. If no
``input_shape`` is passed, no shape inference will be performed.
*layers, **named_layers
The inputs are treated as a list of layers to be
applied in sequence. The output shape of each layer is used to
infer the shape of the following layer. If a tuple is returned,
only the shape of the first element is used to determine input
shape of the next layer (e.g. RNN returns output, hidden).
Example
-------
>>> inputs = paddle.rand(10, 40, 50)
>>> model = Sequential(input_shape=inputs.shape)
>>> model.append(Linear, n_neurons=100, layer_name="layer1")
>>> model.append(Linear, n_neurons=200, layer_name="layer2")
>>> outputs = model(inputs)
>>> outputs.shape
paddle.shape([10, 40, 200])
>>> outputs = model.layer1(inputs)
>>> outputs.shape
paddle.shape([10, 40, 100])
"""
def __init__(self, *layers, input_shape=None, **named_layers):
super().__init__()
# Make sure either layers or input_shape is passed
if not layers and input_shape is None and not named_layers:
raise ValueError("Must pass either layers or input shape")
# Keep track of what layers need "lengths" passed
self.length_layers = []
# Replace None dimensions with arbitrary value
self.input_shape = input_shape
if input_shape and None in input_shape:
self.input_shape = list(input_shape)
for i, dim in enumerate(self.input_shape):
# To reduce size of dummy tensors, use 1 for batch dim
if i == 0 and dim is None:
dim = 1
# Use 256 as a nice round arbitrary value, big enough that
# halving this dimension a few times doesn't reach 1
self.input_shape[i] = dim or 256
# Append non-named layers
for layer in layers:
self.append(layer)
# Append named layers
for name, layer in named_layers.items():
self.append(layer, layer_name=name)
def append(self, layer, *args, layer_name=None, **kwargs):
"""Add a layer to the list of layers, inferring shape if necessary.
Arguments
---------
layer : A paddle.nn.Module class or object
If the layer is a class, it should accept an argument called
``input_shape`` which will be inferred and passed. If the layer
is a module object, it is added as-is.
layer_name : str
The name of the layer, for reference. If the name is in use,
``_{count}`` will be appended.
*args, **kwargs
These are passed to the layer if it is constructed.
"""
# Compute layer_name
if layer_name is None:
layer_name = str(len(self))
elif layer_name in self:
index = 0
while f"{layer_name}_{index}" in self:
index += 1
layer_name = f"{layer_name}_{index}"
# Check if it needs to be constructed with input shape
if self.input_shape:
argspec = inspect.getfullargspec(layer)
if "input_shape" in argspec.args + argspec.kwonlyargs:
input_shape = self.get_output_shape()
layer = layer(*args, input_shape=input_shape, **kwargs)
# Finally, append the layer.
try:
self[layer_name] = layer
# self.add_module(layer_name, layer)
except TypeError:
raise ValueError(
"Must pass `input_shape` at initialization and use "
"modules that take `input_shape` to infer shape when "
"using `append()`."
)
def get_output_shape(self):
"""Returns expected shape of the output.
Computed by passing dummy input constructed with the
``self.input_shape`` attribute.
"""
with paddle.no_grad():
dummy_input = paddle.zeros(self.input_shape)
dummy_output = self(dummy_input)
return dummy_output.shape
def forward(self, x):
"""Applies layers in sequence, passing only the first element of tuples.
Arguments
---------
x : paddle.Tensor
The input tensor to run through the network.
"""
for layer in self.values():
x = layer(x)
if isinstance(x, tuple):
x = x[0]
return x

@ -0,0 +1,73 @@
"""Library implementing linear transformation.
Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import logging
import paddle
import paddle.nn as nn
from paddlespeech.s2t.modules import align
logger = logging.getLogger(__name__)
class Linear(paddle.nn.Layer):
"""Computes a linear transformation y = wx + b.
Arguments
---------
n_neurons : int
It is the number of output neurons (i.e, the dimensionality of the
output).
input_shape: tuple
It is the shape of the input tensor.
input_size: int
Size of the input tensor.
bias : bool
If True, the additive bias b is adopted.
combine_dims : bool
If True and the input is 4D, combine 3rd and 4th dimensions of input.
Example
-------
>>> inputs = paddle.rand(10, 50, 40)
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
>>> output = lin_t(inputs)
>>> output.shape
paddle.shape([10, 50, 100])
"""
def __init__(
self,
n_neurons,
input_shape=None,
input_size=None,
bias=True,
combine_dims=False,
):
super().__init__()
self.combine_dims = combine_dims
if input_shape is None and input_size is None:
raise ValueError("Expected one of input_shape or input_size")
if input_size is None:
input_size = input_shape[-1]
if len(input_shape) == 4 and self.combine_dims:
input_size = input_shape[2] * input_shape[3]
# Weights are initialized following paddle approach
self.w = align.Linear(input_size, n_neurons, bias_attr=bias)
def forward(self, x):
"""Returns the linear transformation of input tensor.
Arguments
---------
x : paddle.Tensor
Input to transform linearly.
"""
if x.ndim == 4 and self.combine_dims:
    x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]])
wx = self.w(x)
return wx
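# Editorial sketch of the combine_dims path (shapes are illustrative only):
#     >>> x = paddle.rand([10, 50, 4, 10])      # [batch, time, ch1, ch2]
#     >>> lin_t = Linear(input_shape=(10, 50, 4, 10), n_neurons=100, combine_dims=True)
#     >>> lin_t(x).shape                        # ch1 * ch2 = 40 features are fused first
#     [10, 50, 100]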

@ -0,0 +1,256 @@
"""
Low level signal processing utilities
Authors
* Peter Plantinga 2020
* Francois Grondin 2020
* William Aris 2020
* Samuele Cornell 2020
* Sarthak Yadav 2022
"""
import paddle
import math
from packaging import version
import numpy as np
def blackman_window(window_length, periodic=True):
if window_length == 0:
return []
if window_length == 1:
return paddle.ones([1])
if periodic:
window_length += 1
window = paddle.arange(window_length) * (np.pi / (window_length - 1))
window = 0.08 * paddle.cos(window * 4) - 0.5 * paddle.cos(window * 2) + 0.42
return window[:-1] if periodic else window
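# Editorial note: with periodic=False this reproduces the classic Blackman window
# 0.42 - 0.5*cos(2*pi*n/(N-1)) + 0.08*cos(4*pi*n/(N-1)); e.g. blackman_window(5, periodic=False)
# is approximately [0.0, 0.34, 1.0, 0.34, 0.0]. With periodic=True, N + 1 points are
# computed and the last one is dropped, the form normally used before an FFT.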
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
"""Compute amplitude of a batch of waveforms.
Arguments
---------
waveform : tensor
The waveforms used for computing amplitude.
Shape should be `[time]` or `[batch, time]` or
`[batch, time, channels]`.
lengths : tensor
The lengths of the waveforms excluding the padding.
Shape should be a single dimension, `[batch]`.
amp_type : str
Whether to compute "avg" average or "peak" amplitude.
Choose between ["avg", "peak"].
scale : str
Whether to compute amplitude in "dB" or "linear" scale.
Choose between ["linear", "dB"].
Returns
-------
The average amplitude of the waveforms.
Example
-------
>>> signal = torch.sin(torch.arange(16000.0)).unsqueeze(0)
>>> compute_amplitude(signal, signal.size(1))
tensor([[0.6366]])
"""
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0)
assert amp_type in ["avg", "peak"]
assert scale in ["linear", "dB"]
if amp_type == "avg":
if lengths is None:
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0]
else:
raise NotImplementedError
if scale == "linear":
return out
elif scale == "dB":
return paddle.clip(20 * paddle.log10(out), min=-80) # clamp zeros
else:
raise NotImplementedError
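# Editorial sketch of the docstring example above using paddle ops (value approximate):
#     >>> signal = paddle.sin(paddle.arange(16000.0)).unsqueeze(0)
#     >>> compute_amplitude(signal, signal.shape[1])
#     # ~0.6366, the mean of |sin| over many periods (2 / pi)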
def convolve1d(
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1,
use_fft=False,
rotation_index=0,
):
"""Use torch.nn.functional to perform 1d padding and conv.
Arguments
---------
waveform : tensor
The tensor to perform operations on.
kernel : tensor
The filter to apply during convolution.
padding : int or tuple
The padding (pad_left, pad_right) to apply.
If an integer is passed instead, this is passed
to the conv1d function and pad_type is ignored.
pad_type : str
The type of padding to use. Passed directly to
`paddle.nn.functional.pad`, see the Paddle documentation
for available options.
stride : int
The number of units to move each time convolution is applied.
Passed to conv1d. Has no effect if `use_fft` is True.
groups : int
This option is passed to `conv1d` to split the input into groups for
convolution. Input channels should be divisible by the number of groups.
use_fft : bool
When `use_fft` is passed `True`, then compute the convolution in the
spectral domain using complex multiply. This is more efficient on CPU
when the size of the kernel is large (e.g. reverberation). WARNING:
Without padding, circular convolution occurs. This makes little
difference in the case of reverberation, but may make more difference
with different kernels.
rotation_index : int
This option only applies if `use_fft` is true. If so, the kernel is
rolled by this amount before convolution to shift the output location.
Returns
-------
The convolved waveform.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0).unsqueeze(2)
>>> kernel = torch.rand(1, 10, 1)
>>> signal = convolve1d(signal, kernel, padding=(9, 0))
"""
if len(waveform.shape) != 3:
raise ValueError("Convolve1D expects a 3-dimensional tensor")
# Move time dimension last, which pad and fft and conv expect.
waveform = waveform.transpose([0, 2, 1])
kernel = kernel.transpose([0, 2, 1])
# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, tuple):
waveform = paddle.nn.functional.pad(
x=waveform, pad=padding, mode=pad_type,
)
# This approach uses FFT, which is more efficient if the kernel is large
if use_fft:
# Pad kernel to same length as signal, ensuring correct alignment
zero_length = waveform.shape[-1] - kernel.shape[-1]
# Handle case where signal is shorter
if zero_length < 0:
kernel = kernel[..., :zero_length]
zero_length = 0
# Perform rotation to ensure alignment
zeros = paddle.zeros([kernel.shape[0], kernel.shape[1], zero_length])
after_index = kernel[..., rotation_index:]
before_index = kernel[..., :rotation_index]
kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
# Multiply in frequency domain to convolve in time domain
# if version.parse(torch.__version__) > version.parse("1.6.0"):
import paddle.fft as fft
result = fft.rfft(waveform) * fft.rfft(kernel)
convolved = fft.irfft(result, n=waveform.shape[-1])
# else:
# f_signal = torch.rfft(waveform, 1)
# f_kernel = torch.rfft(kernel, 1)
# sig_real, sig_imag = f_signal.unbind(-1)
# ker_real, ker_imag = f_kernel.unbind(-1)
# f_result = torch.stack(
# [
# sig_real * ker_real - sig_imag * ker_imag,
# sig_real * ker_imag + sig_imag * ker_real,
# ],
# dim=-1,
# )
# convolved = torch.irfft(
# f_result, 1, signal_sizes=[waveform.size(-1)]
# )
# Use the paddle implementation, which should be efficient on GPU
else:
convolved = paddle.nn.functional.conv1d(
x=waveform,
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, tuple) else 0,
)
# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
"""Returns a notch filter constructed from a high-pass and low-pass filter.
(from https://tomroelandts.com/articles/
how-to-create-simple-band-pass-and-band-reject-filters)
Arguments
---------
notch_freq : float
frequency to put notch as a fraction of the
sampling rate / 2. The range of possible inputs is 0 to 1.
filter_width : int
Filter width in samples. Longer filters have
smaller transition bands, but are more inefficient.
notch_width : float
Width of the notch, as a fraction of the sampling_rate / 2.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0).unsqueeze(2)
>>> kernel = notch_filter(0.25)
>>> notched_signal = convolve1d(signal, kernel)
"""
# Check inputs
assert 0 < notch_freq <= 1
assert filter_width % 2 != 0
pad = filter_width // 2
inputs = paddle.arange(filter_width) - pad
# Avoid frequencies that are too low
notch_freq += notch_width
# Define sinc function, avoiding division by zero
def sinc(x):
"Computes the sinc function."
def _sinc(x):
return paddle.sin(x) / x
# The zero is at the middle index
return paddle.concat([_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1 :])])
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
hlpf *= blackman_window(filter_width)
hlpf /= paddle.sum(hlpf)
# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
hhpf *= blackman_window(filter_width)
hhpf /= -paddle.sum(hhpf)
hhpf[pad] += 1
# Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1)

@ -0,0 +1,741 @@
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.speechbrain.processing.signal_processing import (
compute_amplitude,
convolve1d,
notch_filter)
class SpeedPerturb(nn.Layer):
"""Slightly speed up or slow down an audio signal.
Resample the audio signal at a rate that is similar to the original rate,
to achieve a slightly slower or slightly faster signal. This technique is
outlined in the paper: "Audio Augmentation for Speech Recognition"
Arguments
---------
orig_freq : int
The frequency of the original signal.
speeds : list
The speeds that the signal should be changed to, as a percentage of the
original signal (i.e. `speeds` is divided by 100 to get a ratio).
perturb_prob : float
The chance that the batch will be speed-
perturbed. By default, every batch is perturbed.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> perturbator = SpeedPerturb(orig_freq=16000, speeds=[90])
>>> clean = signal.unsqueeze(0)
>>> perturbed = perturbator(clean)
>>> clean.shape
torch.Size([1, 52173])
>>> perturbed.shape
torch.Size([1, 46956])
"""
def __init__(
self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0,
):
super().__init__()
self.orig_freq = orig_freq
self.speeds = speeds
self.perturb_prob = perturb_prob
# Initialize index of perturbation
self.samp_index = 0
# Initialize resamplers
self.resamplers = []
for speed in self.speeds:
config = {
"orig_freq": self.orig_freq,
"new_freq": self.orig_freq * speed // 100,
}
self.resamplers.append(Resample(**config))
def forward(self, waveform):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
if paddle.rand([1]) > self.perturb_prob:
return waveform.clone()
# Perform a random perturbation
self.samp_index = paddle.randint(len(self.speeds), shape=(1,))[0]
perturbed_waveform = self.resamplers[self.samp_index](waveform)
return perturbed_waveform
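# Editorial sketch of the docstring example with paddle tensors: the output length
# follows the resampling ratio new_freq / orig_freq = speed / 100.
#     >>> perturbator = SpeedPerturb(orig_freq=16000, speeds=[90])
#     >>> clean = paddle.randn([1, 52173])
#     >>> perturbator(clean).shape
#     [1, 46956]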
class Resample(nn.Layer):
"""This class resamples an audio signal using sinc-based interpolation.
It is a modification of the `resample` function from torchaudio
(https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html)
Arguments
---------
orig_freq : int
the sampling frequency of the input signal.
new_freq : int
the new sampling frequency after this operation is performed.
lowpass_filter_width : int
Controls the sharpness of the filter, larger numbers result in a
sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0) # [batch, time, channels]
>>> resampler = Resample(orig_freq=16000, new_freq=8000)
>>> resampled = resampler(signal)
>>> signal.shape
torch.Size([1, 52173])
>>> resampled.shape
torch.Size([1, 26087])
"""
def __init__(
self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
):
super().__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.lowpass_filter_width = lowpass_filter_width
# Compute rate for striding
self._compute_strides()
assert self.orig_freq % self.conv_stride == 0
assert self.new_freq % self.conv_transpose_stride == 0
def _compute_strides(self):
"""Compute the phases in polyphase filter.
(almost directly from torchaudio.compliance.kaldi)
"""
# Compute new unit based on ratio of in/out frequencies
base_freq = math.gcd(self.orig_freq, self.new_freq)
input_samples_in_unit = self.orig_freq // base_freq
self.output_samples = self.new_freq // base_freq
# Store the appropriate stride based on the new units
self.conv_stride = input_samples_in_unit
self.conv_transpose_stride = self.output_samples
def forward(self, waveforms):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
if not hasattr(self, "first_indices"):
self._indices_and_weights(waveforms)
# Don't do anything if the frequencies are the same
if self.orig_freq == self.new_freq:
return waveforms
unsqueezed = False
if len(waveforms.shape) == 2:
waveforms = waveforms.unsqueeze(1)
unsqueezed = True
elif len(waveforms.shape) == 3:
waveforms = waveforms.transpose([0, 2, 1])
else:
raise ValueError("Input must be 2 or 3 dimensions")
# Do resampling
resampled_waveform = self._perform_resample(waveforms)
if unsqueezed:
resampled_waveform = resampled_waveform.squeeze(1)
else:
resampled_waveform = resampled_waveform.transpose([0, 2, 1])
return resampled_waveform
def _perform_resample(self, waveforms):
"""Resamples the waveform at the new frequency.
This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a
LinearResample (resample a signal at linearly spaced intervals to
up/downsample a signal). LinearResample (LR) means that the output
signal is at linearly spaced intervals (i.e the output signal has a
frequency of `new_freq`). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
(almost directly from torchaudio.compliance.kaldi)
https://ccrma.stanford.edu/~jos/resample/
Theory_Ideal_Bandlimited_Interpolation.html
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
Arguments
---------
waveforms : tensor
The batch of audio signals to resample.
Returns
-------
The waveforms at the new frequency.
"""
# Compute output size and initialize
batch_size, num_channels, wave_len = waveforms.shape
window_size = self.weights.shape[1]
tot_output_samp = self._output_samples(wave_len)
resampled_waveform = paddle.zeros(
(batch_size, num_channels, tot_output_samp)
)
# self.weights = self.weights.to(waveforms.device)
# Check weights are on correct device
# if waveforms.device != self.weights.device:
# self.weights = self.weights.to(waveforms.device)
# eye size: (num_channels, num_channels, 1)
eye = paddle.eye(num_channels).unsqueeze(2)
# Iterate over the phases in the polyphase filter
for i in range(self.first_indices.shape[0]):
wave_to_conv = waveforms
first_index = int(self.first_indices[i].item())
if first_index >= 0:
# trim the signal as the filter will not be applied
# before the first_index
wave_to_conv = wave_to_conv[..., first_index:]
# pad the right of the signal to allow partial convolutions
# meaning compute values for partial windows (e.g. end of the
# window is outside the signal length)
max_index = (tot_output_samp - 1) // self.output_samples
end_index = max_index * self.conv_stride + window_size
current_wave_len = wave_len - first_index
right_padding = max(0, end_index + 1 - current_wave_len)
left_padding = max(0, -first_index)
wave_to_conv = paddle.nn.functional.pad(
wave_to_conv, (left_padding, right_padding), data_format='NCL'
)
conv_wave = paddle.nn.functional.conv1d(
x=wave_to_conv,
weight=self.weights[i].repeat(num_channels, 1, 1),
stride=self.conv_stride,
groups=num_channels,
)
# we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride]
dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
conv_wave, eye, stride=self.conv_transpose_stride
)
# pad dilated_conv_wave so it reaches the output length if needed.
left_padding = i
previous_padding = left_padding + dilated_conv_wave.shape[-1]
right_padding = max(0, tot_output_samp - previous_padding)
dilated_conv_wave = paddle.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding), data_format='NCL'
)
dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
resampled_waveform += dilated_conv_wave
return resampled_waveform
def _output_samples(self, input_num_samp):
"""Based on LinearResample::GetNumOutputSamples.
LinearResample (LR) means that the output signal is at
linearly spaced intervals (i.e the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited
interpolation to upsample/downsample the signal.
(almost directly from torchaudio.compliance.kaldi)
Arguments
---------
input_num_samp : int
The number of samples in each example in the batch.
Returns
-------
Number of samples in the output waveform.
"""
# For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
# where tick_freq is the least common multiple of samp_in and
# samp_out.
samp_in = int(self.orig_freq)
samp_out = int(self.new_freq)
tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
ticks_per_input_period = tick_freq // samp_in
# work out the number of ticks in the time interval
# [ 0, input_num_samp/samp_in ).
interval_length = input_num_samp * ticks_per_input_period
if interval_length <= 0:
return 0
ticks_per_output_period = tick_freq // samp_out
# Get the last output-sample in the closed interval,
# i.e. replacing [ ) with [ ]. Note: integer division rounds down.
# See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
# explanation of the notation.
last_output_samp = interval_length // ticks_per_output_period
# We need the last output-sample in the open interval, so if it
# takes us to the end of the interval exactly, subtract one.
if last_output_samp * ticks_per_output_period == interval_length:
last_output_samp -= 1
# First output-sample index is zero, so the number of output samples
# is the last output-sample plus one.
num_output_samp = last_output_samp + 1
return num_output_samp
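# Editorial worked example of the tick arithmetic above: downsampling 100 input
# samples from 16000 Hz to 8000 Hz gives tick_freq = lcm(16000, 8000) = 16000,
# ticks_per_input_period = 1, interval_length = 100, ticks_per_output_period = 2,
# last_output_samp = 100 // 2 = 50; since 50 * 2 == 100 it is reduced to 49, and
# the method returns 49 + 1 = 50 output samples, i.e. half the input length.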
def _indices_and_weights(self, waveforms):
"""Based on LinearResample::SetIndexesAndWeights
Retrieves the weights for resampling as well as the indices in which
they are valid. LinearResample (LR) means that the output signal is at
linearly spaced intervals (i.e the output signal has a frequency
of ``new_freq``). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
Returns
-------
- the place where each filter should start being applied
- the filters to be applied to the signal for resampling
"""
# Lowpass filter frequency depends on smaller of two frequencies
min_freq = min(self.orig_freq, self.new_freq)
lowpass_cutoff = 0.99 * 0.5 * min_freq
assert lowpass_cutoff * 2 <= min_freq
window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
output_t = paddle.arange(
start=0.0, end=self.output_samples
)
output_t /= self.new_freq
min_t = output_t - window_width
max_t = output_t + window_width
min_input_index = paddle.ceil(min_t * self.orig_freq)
max_input_index = paddle.floor(max_t * self.orig_freq)
num_indices = max_input_index - min_input_index + 1
max_weight_width = num_indices.max()
j = paddle.arange(max_weight_width)
input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)
weights = paddle.zeros_like(delta_t)
inside_window_indices = delta_t.abs() < (window_width)
# raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * (
1
+ paddle.cos(
2
* math.pi
* lowpass_cutoff
/ self.lowpass_filter_width
* delta_t[inside_window_indices]
)
)
t_eq_zero_indices = delta_t == 0.0
t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function
weights[t_not_eq_zero_indices] *= paddle.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
) / (math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
# size (output_samples, max_weight_width)
weights /= self.orig_freq
self.first_indices = min_input_index
self.weights = weights
class DropFreq(nn.Layer):
"""This class drops a random frequency from the signal.
The purpose of this class is to teach models to learn to rely on all parts
of the signal, not just a few frequency bands.
Arguments
---------
drop_freq_low : float
The low end of frequencies that can be dropped,
as a fraction of the sampling rate / 2.
drop_freq_high : float
The high end of frequencies that can be
dropped, as a fraction of the sampling rate / 2.
drop_count_low : int
The low end of number of frequencies that could be dropped.
drop_count_high : int
The high end of number of frequencies that could be dropped.
drop_width : float
The width of the frequency band to drop, as
a fraction of the sampling_rate / 2.
drop_prob : float
The probability that the batch of signals will have a frequency
dropped. By default, every batch has frequencies dropped.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropFreq()
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> dropped_signal = dropper(signal.unsqueeze(0))
"""
def __init__(
self,
drop_freq_low=1e-14,
drop_freq_high=1,
drop_count_low=1,
drop_count_high=2,
drop_width=0.05,
drop_prob=1,
):
super().__init__()
self.drop_freq_low = drop_freq_low
self.drop_freq_high = drop_freq_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_width = drop_width
self.drop_prob = drop_prob
def forward(self, waveforms):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
# Don't drop (return early) 1-`drop_prob` portion of the batches
dropped_waveform = waveforms.clone()
if paddle.rand([1]) > self.drop_prob:
return dropped_waveform
# Add channels dimension
if len(waveforms.shape) == 2:
dropped_waveform = dropped_waveform.unsqueeze(-1)
# Pick number of frequencies to drop
drop_count = paddle.randint(
low=self.drop_count_low, high=self.drop_count_high + 1, shape=(1,),
)
# Pick a frequency to drop
drop_range = self.drop_freq_high - self.drop_freq_low
drop_frequency = (
paddle.rand(drop_count) * drop_range + self.drop_freq_low
)
# Filter parameters
filter_length = 101
pad = filter_length // 2
# Start with delta function
drop_filter = paddle.zeros([1, filter_length, 1])
drop_filter[0, pad, 0] = 1
# Subtract each frequency
for frequency in drop_frequency:
notch_kernel = notch_filter(
frequency, filter_length, self.drop_width,
)
drop_filter = convolve1d(drop_filter, notch_kernel, pad)
# Apply filter
dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)
# Remove channels dimension if added
return dropped_waveform.squeeze(-1)
class DropChunk(nn.Layer):
"""This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps models learn to rely
on all parts of the signal, since it can't expect a given part to be
present.
Arguments
---------
drop_length_low : int
The low end of lengths for which to set the
signal to zero, in samples.
drop_length_high : int
The high end of lengths for which to set the
signal to zero, in samples.
drop_count_low : int
The low end of number of times that the signal
can be dropped to zero.
drop_count_high : int
The high end of number of times that the signal
can be dropped to zero.
drop_start : int
The first index for which dropping will be allowed.
drop_end : int
The last index for which dropping will be allowed.
drop_prob : float
The probability that the batch of signals will
have a portion dropped. By default, every batch
has portions dropped.
noise_factor : float
The factor relative to average amplitude of an utterance
to use for scaling the white noise inserted. 1 keeps
the average amplitude the same, while 0 inserts all 0's.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropChunk(drop_start=100, drop_end=200, noise_factor=0.)
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0) # [batch, time]
>>> length = paddle.ones([1])
>>> dropped_signal = dropper(signal, length)
>>> float(dropped_signal[:, 150])
0.0
"""
def __init__(
self,
drop_length_low=100,
drop_length_high=1000,
drop_count_low=1,
drop_count_high=10,
drop_start=0,
drop_end=None,
drop_prob=1,
noise_factor=0.0,
):
super().__init__()
self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_start = drop_start
self.drop_end = drop_end
self.drop_prob = drop_prob
self.noise_factor = noise_factor
# Validate low < high
if drop_length_low > drop_length_high:
raise ValueError("Low limit must not be more than high limit")
if drop_count_low > drop_count_high:
raise ValueError("Low limit must not be more than high limit")
# Make sure the length doesn't exceed end - start
if drop_end is not None and drop_end >= 0:
if drop_start > drop_end:
raise ValueError("Low limit must not be more than high limit")
drop_range = drop_end - drop_start
self.drop_length_low = min(drop_length_low, drop_range)
self.drop_length_high = min(drop_length_high, drop_range)
def forward(self, waveforms, lengths):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or
`[batch, time, channels]`
"""
# Reading input list
lengths = (lengths * waveforms.shape[1]).long()
batch_size = waveforms.shape[0]
dropped_waveform = waveforms.clone()
# Don't drop (return early) 1-`drop_prob` portion of the batches
if paddle.rand([1]) > self.drop_prob:
return dropped_waveform
# Store original amplitude for computing white noise amplitude
clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))
# Pick a number of times to drop
drop_times = paddle.randint(
low=self.drop_count_low,
high=self.drop_count_high + 1,
shape=(batch_size,),
)
# Iterate batch to set mask
for i in range(batch_size):
if drop_times[i] == 0:
continue
# Pick lengths
length = paddle.randint(
low=self.drop_length_low,
high=self.drop_length_high + 1,
shape=(drop_times[i],),
)
# Compute range of starting locations
start_min = self.drop_start
if start_min < 0:
start_min += lengths[i]
start_max = self.drop_end
if start_max is None:
start_max = lengths[i]
if start_max < 0:
start_max += lengths[i]
start_max = max(0, start_max - length.max())
# Pick starting locations
start = paddle.randint(
low=start_min, high=start_max + 1, shape=(drop_times[i],),
)
end = start + length
# Update waveform
if not self.noise_factor:
for j in range(drop_times[i]):
dropped_waveform[i, start[j] : end[j]] = 0.0
else:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
noise_max = 2 * clean_amplitude[i] * self.noise_factor
for j in range(drop_times[i]):
# zero-center the noise distribution
noise_vec = paddle.rand(length[j])
noise_vec = 2 * noise_max * noise_vec - noise_max
dropped_waveform[i, start[j] : end[j]] = noise_vec
return dropped_waveform
class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
2. Drop frequency bands (with band-drop filters)
3. Speed perturbation (via resampling to slightly different rate)
Arguments
---------
perturb_prob : float from 0 to 1
The probability that a batch will have speed perturbation applied.
drop_freq_prob : float from 0 to 1
The probability that a batch will have frequencies dropped.
drop_chunk_prob : float from 0 to 1
The probability that a batch will have chunks dropped.
speeds : list of ints
A set of different speeds to use to perturb each batch.
See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
sample_rate : int
Sampling rate of the input waveforms.
drop_freq_count_low : int
Lowest number of frequencies that could be dropped.
drop_freq_count_high : int
Highest number of frequencies that could be dropped.
drop_chunk_count_low : int
Lowest number of chunks that could be dropped.
drop_chunk_count_high : int
Highest number of chunks that could be dropped.
drop_chunk_length_low : int
Lowest length of chunks that could be dropped.
drop_chunk_length_high : int
Highest length of chunks that could be dropped.
drop_chunk_noise_factor : float
The noise factor used to scale the white noise inserted, relative to
the average amplitude of the utterance. Default 0 (no noise inserted).
Example
-------
>>> inputs = paddle.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, paddle.ones([10]))
>>> feats.shape
[10, 12800]
"""
def __init__(
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0,
):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
)
self.drop_freq = DropFreq(
drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high,
)
self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor,
)
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
Arguments
---------
waveforms : paddle.Tensor
The waveforms to distort.
lengths : paddle.Tensor
The relative lengths of the waveforms in the batch.
"""
# Augmentation
with paddle.no_grad():
waveforms = self.speed_perturb(waveforms)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
return waveforms
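# Note: `lengths` is expected to be relative (a fraction of the padded length),
# so it stays valid after speed perturbation changes the absolute number of
# samples; DropChunk rescales it against the new waveforms.shape[1].
# Hedged usage sketch with random data (shapes are illustrative):
#     augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[90, 100, 110])
#     wavs = paddle.randn([8, 16000])
#     out = augment(wavs, paddle.ones([8]))  # time dim depends on the sampled speed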

@ -0,0 +1,14 @@
import paddle
import numpy as np
def blackman_window(window_length, periodic=True):
# Blackman window: w(n) = 0.42 - 0.5*cos(2*pi*n/(N-1)) + 0.08*cos(4*pi*n/(N-1)).
# `periodic=True` computes a length N+1 window and drops the last sample.
if window_length == 0:
return paddle.zeros([0])
if window_length == 1:
return paddle.ones([1])
if periodic:
window_length += 1
window = paddle.arange(window_length, dtype='float32') * (np.pi / (window_length - 1))
window = 0.08 * paddle.cos(window * 4) - 0.5 * paddle.cos(window * 2) + 0.42
return window[:-1] if periodic else window
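# Hedged sanity check: with periodic=False this matches numpy's symmetric
# Blackman window; periodic=True follows the usual "length N+1, drop last" rule:
#     w = blackman_window(16, periodic=False)
#     assert np.allclose(w.numpy(), np.blackman(16), atol=1e-6)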

@ -0,0 +1,287 @@
import os
import sys
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from yacs.config import CfgNode
from paddlespeech.s2t.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.models.wav2vec2.speechbrain.lobes.models.VanillaNN import VanillaNN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import log_add
# `sys` and `logger` are needed by decode(); Log is PaddleSpeech's standard logger.
logger = Log(__name__).getlog()
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):
super().__init__()
wav2vec2_config = Wav2Vec2ConfigPure()
wav2vec2 = Wav2Vec2Model(wav2vec2_config)
model_dict = paddle.load(config.wav2vec2_params_path)
wav2vec2.set_state_dict(model_dict)
wav2vec2.eval()
self.normalize_wav = config.normalize_wav
self.output_norm = config.output_norm
if config.freeze_wav2vec2:
for parm in wav2vec2.parameters():
parm.trainable = False
self.wav2vec2 = wav2vec2
self.enc = VanillaNN(input_shape=[None,None,wav2vec2_config.hidden_size], activation=nn.LeakyReLU, dnn_blocks=config.dnn_blocks, dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim, enc_n_units=config.dnn_neurons, blank_id=config.blank_id, dropout_rate=config.ctc_dropout_rate, reduction_type="mean")
def train_batch(self):
wav, wavs_lens_rate, target, target_lens_rate = self._get_data()
ctc_loss = self(wav, wavs_lens_rate, target, target_lens_rate)
def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# NOTE: appears to be a debugging dump of the wav2vec2 features; remove for real training.
np.save("data/out.npy", out.numpy())
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss
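# Hedged note on the relative-length convention used above: x_lens and
# target_lens are recovered from fractions of the padded axis, e.g. a relative
# length of 0.5 with x.shape[1] == 500 gives 250 valid encoder frames.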
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int):
batch_size = feats.shape[0]
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'ctc_greedy_search':
hyps = self.ctc_greedy_search(feats, feats_lengths)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
# ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
feats,
feats_lengths,
beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
raise ValueError(f"wav2vec2 not support decoding method: {decoding_method}")
return res, res_tokenids
@classmethod
def from_config(cls, config):
model = cls(config)
return model
def ctc_greedy_search(
self, wav, wavs_lens_rate) -> List[List[int]]:
""" Apply CTC greedy search
Args:
speech (paddle.Tensor): (batch, max_len)
speech_length (paddle.Tensor): (batch, )
Returns:
List[List[int]]: best path result
"""
batch_size = wav.shape[0]
wav = wav[:,:,0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
x_lens = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
# pad_mask = make_pad_mask(x_lens) # (B, maxlen)
# topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
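# Hedged illustration of the collapse rule applied above (assuming blank_id == 0):
# merge consecutive repeats first, then drop blanks, e.g.
#     remove_duplicates_and_blank([0, 3, 3, 0, 0, 4, 4, 0]) -> [3, 4]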
def _ctc_prefix_beam_search(
self, wav, wavs_lens_rate, beam_size, blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
wav = wav[:,:,0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
maxlen = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps
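# Hedged numeric note on the log-domain bookkeeping above: scores are kept as
# log-probabilities, so summing alternative paths uses log_add, i.e.
# log_add([a, b]) == log(exp(a) + exp(b)). Merging two equally likely paths:
#     log_add([-1.3863, -1.3863]) ≈ -0.6931   # ln(0.25) paths combine to ln(0.5)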
def ctc_prefix_beam_search(self, wav, wavs_lens_rate, beam_size) -> List[int]:
""" Apply CTC prefix beam search
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps = self._ctc_prefix_beam_search(
wav, wavs_lens_rate, beam_size)
return hyps[0][0]
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output, (B, T, D)
Returns:
paddle.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
def _get_data(self):
data_dir = "data"
wavs = np.load(os.path.join(data_dir, "wavs.npy"))
wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
tokens = np.load(os.path.join(data_dir, "tokens.npy"))
tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
return batch
if __name__ == "__main__":
# wav2vec2_asr = Wav2vec2ASR(config={})
# wav2vec2_asr.train_batch()
freeze = True
config = Wav2Vec2ConfigPure()
model = Wav2Vec2Model(config)
model_dict = model.state_dict()
revise_params_path = "exp/torch_to_paddle_revise.pdparams"
model_dict_revise = paddle.load(revise_params_path)
model.set_state_dict(model_dict_revise)
model.training = True
model.eval()
if freeze:
for parm in model.parameters():
parm.trainable = False  # paddle parameters use `trainable`, not `requires_grad`
# get enc()
enc = VanillaNN(input_shape=[None,None,1024], activation=paddle.nn.LeakyReLU, dnn_blocks=2, dnn_neurons=1024)
ctc = CTC(odim=30, enc_n_units=1024, blank_id=0, dropout_rate=0.0)
input_values = np.load("input_values.npy")
input_values = paddle.to_tensor(input_values)
feats = model(input_values).last_hidden_state
x = enc(feats)
# NOTE: debug scratch -- CTCDecoderBase expects (hs, hs_lens, ys, ys_lens);
# `target`/`target_lens` are assumed to be prepared elsewhere (e.g. as in _get_data()).
x_lens = paddle.to_tensor([x.shape[1]], dtype='int64')
ctc_loss = ctc(x, x_lens, target, target_lens)

@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math
import paddle import paddle
from paddle import nn from paddle import nn
import math
""" """
To align the initializer between paddle and torch, To align the initializer between paddle and torch,
the API below are set defalut initializer with priority higger than global initializer. the API below are set defalut initializer with priority higger than global initializer.
@ -82,18 +81,10 @@ class Linear(nn.Linear):
name=None): name=None):
if weight_attr is None: if weight_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
if bias_attr is None: if bias_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr( bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
super(Linear, self).__init__(in_features, out_features, weight_attr, super(Linear, self).__init__(in_features, out_features, weight_attr,
bias_attr, name) bias_attr, name)
@ -113,18 +104,10 @@ class Conv1D(nn.Conv1D):
data_format='NCL'): data_format='NCL'):
if weight_attr is None: if weight_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
if bias_attr is None: if bias_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr( bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
super(Conv1D, self).__init__( super(Conv1D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format) groups, padding_mode, weight_attr, bias_attr, data_format)
@ -145,18 +128,10 @@ class Conv2D(nn.Conv2D):
data_format='NCHW'): data_format='NCHW'):
if weight_attr is None: if weight_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
if bias_attr is None: if bias_attr is None:
if global_init_type == "kaiming_uniform": if global_init_type == "kaiming_uniform":
bias_attr = paddle.ParamAttr( bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
initializer=nn.initializer.KaimingUniform(
fan_in=None,
negative_slope=math.sqrt(5),
nonlinearity='leaky_relu'))
super(Conv2D, self).__init__( super(Conv2D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation, in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format) groups, padding_mode, weight_attr, bias_attr, data_format)

@ -15,6 +15,7 @@
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Multi-Head Attention layer definition.""" """Multi-Head Attention layer definition."""
import math import math
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -82,11 +83,10 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v return q, k, v
def forward_attention( def forward_attention(self,
self,
value: paddle.Tensor, value: paddle.Tensor,
scores: paddle.Tensor, scores: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool) mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
) -> paddle.Tensor: ) -> paddle.Tensor:
"""Compute attention context vector. """Compute attention context vector.
Args: Args:
@ -127,14 +127,13 @@ class MultiHeadedAttention(nn.Layer):
return self.linear_out(x) # (batch, time1, d_model) return self.linear_out(x) # (batch, time1, d_model)
def forward( def forward(self,
self,
query: paddle.Tensor, query: paddle.Tensor,
key: paddle.Tensor, key: paddle.Tensor,
value: paddle.Tensor, value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor, # paddle.empty([0]) pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor # paddle.zeros([0,0,0,0]) cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention. """Compute scaled dot product attention.
Args: Args:
@ -244,14 +243,13 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
return x return x
def forward( def forward(self,
self,
query: paddle.Tensor, query: paddle.Tensor,
key: paddle.Tensor, key: paddle.Tensor,
value: paddle.Tensor, value: paddle.Tensor,
mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
pos_emb: paddle.Tensor, # paddle.empty([0]) pos_emb: paddle.Tensor = paddle.empty([0]),
cache: paddle.Tensor # paddle.zeros([0,0,0,0]) cache: paddle.Tensor = paddle.zeros([0,0,0,0])
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding. """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args: Args:

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""ConvolutionModule definition.""" """ConvolutionModule definition."""
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -105,11 +106,10 @@ class ConvolutionModule(nn.Layer):
) )
self.activation = activation self.activation = activation
def forward( def forward(self,
self,
x: paddle.Tensor, x: paddle.Tensor,
mask_pad: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
cache: paddle.Tensor # paddle.zeros([0,0,0,0]) cache: paddle.Tensor= paddle.zeros([0,0,0]),
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module. """Compute convolution module.
Args: Args:

@ -53,7 +53,7 @@ class CTCDecoderBase(nn.Layer):
enc_n_units, enc_n_units,
blank_id=0, blank_id=0,
dropout_rate: float=0.0, dropout_rate: float=0.0,
reduction: bool=True, reduction_type: str="sum",
batch_average: bool=True, batch_average: bool=True,
grad_norm_type: Union[str, None]=None): grad_norm_type: Union[str, None]=None):
"""CTC decoder """CTC decoder
@ -73,7 +73,7 @@ class CTCDecoderBase(nn.Layer):
self.odim = odim self.odim = odim
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
self.ctc_lo = Linear(enc_n_units, self.odim) self.ctc_lo = Linear(enc_n_units, self.odim)
reduction_type = "sum" if reduction else "none" reduction_type = reduction_type if reduction_type else "none"
self.criterion = CTCLoss( self.criterion = CTCLoss(
blank=self.blank_id, blank=self.blank_id,
reduction=reduction_type, reduction=reduction_type,

@ -121,16 +121,11 @@ class DecoderLayer(nn.Layer):
if self.concat_after: if self.concat_after:
tgt_concat = paddle.cat( tgt_concat = paddle.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]), dim=-1)
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear1(tgt_concat) x = residual + self.concat_linear1(tgt_concat)
else: else:
x = residual + self.dropout( x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[
0])
if not self.normalize_before: if not self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
@ -139,15 +134,11 @@ class DecoderLayer(nn.Layer):
x = self.norm2(x) x = self.norm2(x)
if self.concat_after: if self.concat_after:
x_concat = paddle.cat( x_concat = paddle.cat(
(x, self.src_attn(x, memory, memory, memory_mask, (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
paddle.empty([0]),
paddle.zeros([0, 0, 0, 0]))[0]),
dim=-1)
x = residual + self.concat_linear2(x_concat) x = residual + self.concat_linear2(x_concat)
else: else:
x = residual + self.dropout( x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask, self.src_attn(x, memory, memory, memory_mask)[0])
paddle.empty([0]), paddle.zeros([0, 0, 0, 0]))[0])
if not self.normalize_before: if not self.normalize_before:
x = self.norm2(x) x = self.norm2(x)

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet) # Modified from wenet(https://github.com/wenet-e2e/wenet)
"""Encoder definition.""" """Encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
@ -175,9 +177,7 @@ class BaseEncoder(nn.Layer):
decoding_chunk_size, self.static_chunk_size, decoding_chunk_size, self.static_chunk_size,
num_decoding_left_chunks) num_decoding_left_chunks)
for layer in self.encoders: for layer in self.encoders:
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad, xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
paddle.zeros([0, 0, 0, 0]),
paddle.zeros([0, 0, 0, 0]))
if self.normalize_before: if self.normalize_before:
xs = self.after_norm(xs) xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just # Here we assume the mask is not changed in encoder layers, so just
@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor, xs: paddle.Tensor,
offset: int, offset: int,
required_cache_size: int, required_cache_size: int,
att_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]) att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
cnn_cache: paddle.Tensor, # paddle.zeros([0,0,0,0]), cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
att_mask: paddle.Tensor, # paddle.ones([0,0,0], dtype=paddle.bool) att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk """ Forward just one chunk
Args: Args:
@ -252,16 +252,13 @@ class BaseEncoder(nn.Layer):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer( xs, _, new_att_cache, new_cnn_cache = layer(
xs, xs, att_mask, pos_emb,
att_mask, att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
pos_emb, cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
mask_pad=paddle.ones([0, 0, 0], dtype=paddle.bool), )
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
cnn_cache=cnn_cache[i:i + 1]
if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2) # new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2) # new_cnn_cache = (B=1, hidden-dim, cache_t2)
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
if self.normalize_before: if self.normalize_before:
@ -273,6 +270,7 @@ class BaseEncoder(nn.Layer):
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk( def forward_chunk_by_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
@ -317,8 +315,8 @@ class BaseEncoder(nn.Layer):
num_frames = xs.shape[1] num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks required_cache_size = decoding_chunk_size * num_decoding_left_chunks
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]) att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]) cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
outputs = [] outputs = []
offset = 0 offset = 0
@ -328,8 +326,7 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :] chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk( (y, att_cache, cnn_cache) = self.forward_chunk(
chunk_xs, offset, required_cache_size, att_cache, cnn_cache, chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y) outputs.append(y)
offset += y.shape[1] offset += y.shape[1]

@ -76,10 +76,9 @@ class TransformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle. mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
Tensor, # paddle.ones([0, 0, 0], dtype=paddle.bool) att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:
@ -106,8 +105,7 @@ class TransformerEncoderLayer(nn.Layer):
if self.normalize_before: if self.normalize_before:
x = self.norm1(x) x = self.norm1(x)
x_att, new_att_cache = self.self_attn( x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
x, x, x, mask, paddle.empty([0]), cache=att_cache)
if self.concat_after: if self.concat_after:
x_concat = paddle.concat((x, x_att), axis=-1) x_concat = paddle.concat((x, x_att), axis=-1)
@ -195,9 +193,9 @@ class ConformerEncoderLayer(nn.Layer):
x: paddle.Tensor, x: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
pos_emb: paddle.Tensor, pos_emb: paddle.Tensor,
mask_pad: paddle.Tensor, #paddle.ones([0, 0, 0],dtype=paddle.bool) mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
att_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor, # paddle.zeros([0, 0, 0, 0]) cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute encoded features. """Compute encoded features.
Args: Args:

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
class DefaultInitializerContext(object): class DefaultInitializerContext(object):
""" """

@ -103,6 +103,8 @@ class OptimizerFactory():
grad_clip = ClipGradByGlobalNormWithLog( grad_clip = ClipGradByGlobalNormWithLog(
args['grad_clip']) if "grad_clip" in args else None args['grad_clip']) if "grad_clip" in args else None
# grad_clip = paddle.nn.ClipGradByGlobalNorm(
# args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay( weight_decay = L2Decay(
args['weight_decay']) if "weight_decay" in args else None args['weight_decay']) if "weight_decay" in args else None
if weight_decay: if weight_decay:

@ -106,6 +106,59 @@ class ConstantLR(LRScheduler):
def get_lr(self): def get_lr(self):
return self.base_lr return self.base_lr
@register_scheduler
class NewBobScheduler(LRScheduler):
"""
Args:
learning_rate (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``ConstantLR`` instance to schedule learning rate.
"""
def __init__(
self,
learning_rate,
annealing_factor=0.5,
improvement_threshold=0.0025,
patient=0,
):
self.hyperparam_value = learning_rate
self.annealing_factor = annealing_factor
self.improvement_threshold = improvement_threshold
self.patient = patient
self.metric_values = []
self.current_patient = self.patient
def __call__(self, metric_value):
"""Returns the current and new value for the hyperparameter.
Arguments
---------
metric_value : float
A number for determining whether to change the hyperparameter value.
"""
old_value = new_value = self.hyperparam_value
if len(self.metric_values) > 0:
prev_metric = self.metric_values[-1]
# Update value if improvement too small and patience is 0
if prev_metric == 0: # Prevent division by zero
improvement = 0
else:
improvement = (prev_metric - metric_value) / prev_metric
if improvement < self.improvement_threshold:
if self.current_patient == 0:
new_value *= self.annealing_factor
self.current_patient = self.patient
else:
self.current_patient -= 1
# Store relevant info
self.metric_values.append(metric_value)
self.hyperparam_value = new_value
return old_value, new_value
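# Hedged usage sketch: NewBob is driven by a validation metric, not the epoch
# index; the value is annealed only when the relative improvement stays below
# `improvement_threshold` and patience is exhausted:
#     scheduler = NewBobScheduler(learning_rate=1.0, annealing_factor=0.5)
#     scheduler(10.0)   # first call, no history -> (1.0, 1.0)
#     scheduler(10.0)   # zero improvement, patient=0 -> (1.0, 0.5)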
def dynamic_import_scheduler(module): def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically. """Import Scheduler class dynamically.

@ -19,6 +19,8 @@ from pathlib import Path
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
dist.init_parallel_env()
from visualdl import LogWriter from visualdl import LogWriter
from paddlespeech.s2t.training.reporter import ObsScope from paddlespeech.s2t.training.reporter import ObsScope
@ -130,7 +132,9 @@ class Trainer():
latest_n=self.config.checkpoint.latest_n) latest_n=self.config.checkpoint.latest_n)
# set random seed if needed # set random seed if needed
print(args.seed)
if args.seed: if args.seed:
print('***********')
seed_all(args.seed) seed_all(args.seed)
logger.info(f"Set seed {args.seed}") logger.info(f"Set seed {args.seed}")
@ -176,7 +180,7 @@ class Trainer():
def init_parallel(self): def init_parallel(self):
"""Init environment for multiprocess training. """Init environment for multiprocess training.
""" """
dist.init_parallel_env() # dist.init_parallel_env()
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
def save(self, tag=None, infos: dict=None): def save(self, tag=None, infos: dict=None):

@ -25,7 +25,6 @@ asr_python:
cfg_path: # [optional] cfg_path: # [optional]
ckpt_path: # [optional] ckpt_path: # [optional]
decode_method: 'attention_rescoring' decode_method: 'attention_rescoring'
num_decoding_left_chunks: -1
force_yes: True force_yes: True
device: # set 'gpu:id' or 'cpu' device: # set 'gpu:id' or 'cpu'
@ -39,7 +38,6 @@ asr_inference:
lang: 'zh' lang: 'zh'
sample_rate: 16000 sample_rate: 16000
cfg_path: cfg_path:
num_decoding_left_chunks: -1
decode_method: decode_method:
force_yes: True force_yes: True

@ -103,9 +103,7 @@ class OnlineCTCEndpoint:
assert self.num_frames_decoded >= self.trailing_silence_frames assert self.num_frames_decoded >= self.trailing_silence_frames
assert self.frame_shift_in_ms > 0 assert self.frame_shift_in_ms > 0
decoding_something = ( decoding_something = (self.num_frames_decoded > self.trailing_silence_frames) and decoding_something
self.num_frames_decoded > self.trailing_silence_frames
) and decoding_something
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms

@ -21,12 +21,12 @@ import paddle
from numpy import float32 from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer from paddlespeech.server.utils import onnx_infer

@ -21,10 +21,10 @@ import paddle
from numpy import float32 from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.s2t.utils.utility import UpdateConfig

@ -21,10 +21,10 @@ import paddle
from numpy import float32 from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler:
## conformer ## conformer
# cache for conformer online # cache for conformer online
self.att_cache = paddle.zeros([0, 0, 0, 0]) self.att_cache = paddle.zeros([0,0,0,0])
self.cnn_cache = paddle.zeros([0, 0, 0, 0]) self.cnn_cache = paddle.zeros([0,0,0,0])
self.encoder_out = None self.encoder_out = None
# conformer decoding state # conformer decoding state
@ -474,14 +474,9 @@ class PaddleASRConnectionHanddler:
# cur chunk # cur chunk
chunk_xs = self.cached_feat[:, cur:end, :] chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk # forward chunk
(y, self.att_cache, (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk(
self.cnn_cache) = self.model.encoder.forward_chunk( chunk_xs, self.offset, required_cache_size,
chunk_xs, self.att_cache, self.cnn_cache)
self.offset,
required_cache_size,
att_cache=self.att_cache,
cnn_cache=self.cnn_cache,
att_mask=paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y) outputs.append(y)
# update the global offset, in decoding frame unit # update the global offset, in decoding frame unit

@ -68,12 +68,9 @@ class ASREngine(BaseEngine):
return False return False
self.executor._init_from_path( self.executor._init_from_path(
model_type=self.config.model, self.config.model, self.config.lang, self.config.sample_rate,
lang=self.config.lang, self.config.cfg_path, self.config.decode_method,
sample_rate=self.config.sample_rate, self.config.ckpt_path)
cfg_path=self.config.cfg_path,
decode_method=self.config.decode_method,
ckpt_path=self.config.ckpt_path)
logger.info("Initialize ASR server engine successfully on device: %s." % logger.info("Initialize ASR server engine successfully on device: %s." %
(self.device)) (self.device))

@ -105,8 +105,7 @@ class PaddleVectorConnectionHandler:
# we can not reuse the cache io.BytesIO(audio) data, # we can not reuse the cache io.BytesIO(audio) data,
# because the soundfile will change the io.BytesIO(audio) to the end # because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data # thus we should convert the base64 string to io.BytesIO when we need the audio data
if not self.executor._check( if not self.executor._check(io.BytesIO(audio), sample_rate):
io.BytesIO(audio), sample_rate, force_yes=True):
logger.debug("check the audio sample rate occurs error") logger.debug("check the audio sample rate occurs error")
return np.array([0.0]) return np.array([0.0])

@ -11,12 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np import numpy as np
import paddle import paddle
from paddlespeech.t2s.datasets.batch import batch_sequences from paddlespeech.t2s.datasets.batch import batch_sequences
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.modules.nets_utils import get_seg_pos from paddlespeech.t2s.modules.nets_utils import get_seg_pos
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import pad_list
from paddlespeech.t2s.modules.nets_utils import phones_masking from paddlespeech.t2s.modules.nets_utils import phones_masking
from paddlespeech.t2s.modules.nets_utils import phones_text_masking from paddlespeech.t2s.modules.nets_utils import phones_text_masking
@ -485,56 +492,180 @@ def vits_single_spk_batch_fn(examples):
return batch return batch
def vits_multi_spk_batch_fn(examples): # for ERNIE SAT
""" class MLMCollateFn:
Returns: """Functor class of common_collate_fn()"""
Dict[str, Any]:
- text (Tensor): Text index tensor (B, T_text).
- text_lengths (Tensor): Text length tensor (B,).
- feats (Tensor): Feature tensor (B, T_feats, aux_channels).
- feats_lengths (Tensor): Feature length tensor (B,).
- speech (Tensor): Speech waveform tensor (B, T_wav).
- spk_id (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
- spk_emb (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
"""
# fields = ["text", "text_lengths", "feats", "feats_lengths", "speech", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
speech = [np.array(item["wave"], dtype=np.float32) for item in examples]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
feats_lengths = [
np.array(item["feats_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text) def __init__(
feats = batch_sequences(feats) self,
speech = batch_sequences(speech) feats_extract,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
attention_window: int=0,
not_sequence: Collection[str]=(), ):
self.mlm_prob = mlm_prob
self.mean_phn_span = mean_phn_span
self.feats_extract = feats_extract
self.not_sequence = set(not_sequence)
self.attention_window = attention_window
self.seg_emb = seg_emb
self.text_masking = text_masking
# convert each batch to paddle.Tensor def __call__(self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
text = paddle.to_tensor(text) ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
return mlm_collate_fn(
data,
feats_extract=self.feats_extract,
mlm_prob=self.mlm_prob,
mean_phn_span=self.mean_phn_span,
seg_emb=self.seg_emb,
text_masking=self.text_masking,
not_sequence=self.not_sequence)
def mlm_collate_fn(
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
feats_extract=None,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
text_masking: bool=False,
pad_value: int=0,
not_sequence: Collection[str]=(),
) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
uttids = [u for u, _ in data]
data = [d for _, d in data]
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
assert all(not k.endswith("_lens")
for k in data[0]), f"*_lens is reserved: {list(data[0])}"
output = {}
for key in data[0]:
array_list = [d[key] for d in data]
# Assume the first axis is length:
# tensor_list: Batch x (Length, ...)
tensor_list = [paddle.to_tensor(a) for a in array_list]
# tensor: (Batch, Length, ...)
tensor = pad_list(tensor_list, pad_value)
output[key] = tensor
# lens: (Batch,)
if key not in not_sequence:
lens = paddle.to_tensor(
[d[key].shape[0] for d in data], dtype=paddle.int64)
output[key + "_lens"] = lens
feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
feats = paddle.to_tensor(feats) feats = paddle.to_tensor(feats)
text_lengths = paddle.to_tensor(text_lengths) print("feats.shape:", feats.shape)
feats_lengths = paddle.to_tensor(feats_lengths) feats_lens = paddle.shape(feats)[0]
feats = paddle.unsqueeze(feats, 0)
batch = { text = output["text"]
"text": text, text_lens = output["text_lens"]
"text_lengths": text_lengths, align_start = output["align_start"]
"feats": feats, align_start_lens = output["align_start_lens"]
"feats_lengths": feats_lengths, align_end = output["align_end"]
"speech": speech
} max_tlen = max(text_lens)
# spk_emb has a higher priority than spk_id max_slen = max(feats_lens)
if "spk_emb" in examples[0]:
spk_emb = [ speech_pad = feats[:, :max_slen]
np.array(item["spk_emb"], dtype=np.float32) for item in examples
] text_pad = text
spk_emb = batch_sequences(spk_emb) text_mask = make_non_pad_mask(
spk_emb = paddle.to_tensor(spk_emb) text_lens, text_pad, length_dim=1).unsqueeze(-2)
batch["spk_emb"] = spk_emb speech_mask = make_non_pad_mask(
elif "spk_id" in examples[0]: feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
spk_id = paddle.to_tensor(spk_id) span_bdy = None
batch["spk_id"] = spk_id if 'span_bdy' in output.keys():
return batch span_bdy = output['span_bdy']
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
text_pad=text_pad,
text_mask=text_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
# 训练纯中文和纯英文的 -> a3t 没有对 phoneme 做 mask, 只对语音 mask 了
# a3t 和 ernie sat 的区别主要在于做 mask 的时候
else:
masked_pos = phones_masking(
xs_pad=speech_pad,
src_mask=speech_mask,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
mlm_prob=mlm_prob,
mean_phn_span=mean_phn_span,
span_bdy=span_bdy)
text_masked_pos = paddle.zeros(paddle.shape(text_pad))
output_dict = {}
speech_seg_pos, text_seg_pos = get_seg_pos(
speech_pad=speech_pad,
text_pad=text_pad,
align_start=align_start,
align_end=align_end,
align_start_lens=align_start_lens,
seg_emb=seg_emb)
output_dict['speech'] = speech_pad
output_dict['text'] = text_pad
output_dict['masked_pos'] = masked_pos
output_dict['text_masked_pos'] = text_masked_pos
output_dict['speech_mask'] = speech_mask
output_dict['text_mask'] = text_mask
output_dict['speech_seg_pos'] = speech_seg_pos
output_dict['text_seg_pos'] = text_seg_pos
output = (uttids, output_dict)
return output
def build_mlm_collate_fn(
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
n_mels: int=80,
fmin: int=80,
fmax: int=7600,
mlm_prob: float=0.8,
mean_phn_span: int=8,
seg_emb: bool=False,
epoch: int=-1, ):
feats_extract_class = LogMelFBank
feats_extract = feats_extract_class(
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mels=n_mels,
fmin=fmin,
fmax=fmax)
if epoch == -1:
mlm_prob_factor = 1
else:
mlm_prob_factor = 0.8
return MLMCollateFn(
feats_extract=feats_extract,
mlm_prob=mlm_prob * mlm_prob_factor,
mean_phn_span=mean_phn_span,
seg_emb=seg_emb)

@ -1,9 +1,8 @@
import paddle
import math import math
import numpy as np import numpy as np
from paddle.io import BatchSampler from paddle.io import BatchSampler
class ErnieSATSampler(BatchSampler): class ErnieSATSampler(BatchSampler):
"""Sampler that restricts data loading to a subset of the dataset. """Sampler that restricts data loading to a subset of the dataset.
In such case, each process can pass a DistributedBatchSampler instance In such case, each process can pass a DistributedBatchSampler instance
@ -71,7 +70,7 @@ class ErnieSATSampler(BatchSampler):
assert isinstance(drop_last, bool), \ assert isinstance(drop_last, bool), \
"drop_last should be a boolean number" "drop_last should be a boolean number"
from paddle.distributed import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
if num_replicas is not None: if num_replicas is not None:
assert isinstance(num_replicas, int) and num_replicas > 0, \ assert isinstance(num_replicas, int) and num_replicas > 0, \
@ -111,8 +110,8 @@ class ErnieSATSampler(BatchSampler):
subsampled_indices.extend(indices[i:i + self.batch_size]) subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:] indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend( subsampled_indices.extend(indices[
indices[self.local_rank * last_local_batch_size:( self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size]) self.local_rank + 1) * last_local_batch_size])
return subsampled_indices return subsampled_indices

@ -19,9 +19,9 @@ import librosa
import numpy as np import numpy as np
import pypinyin import pypinyin
from praatio import textgrid from praatio import textgrid
from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
DICT_EN = 'tools/aligner/cmudict-0.7b' DICT_EN = 'tools/aligner/cmudict-0.7b'
DICT_ZH = 'tools/aligner/simple.lexicon' DICT_ZH = 'tools/aligner/simple.lexicon'
@ -30,7 +30,6 @@ MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
MFA_PATH = 'tools/montreal-forced-aligner/bin' MFA_PATH = 'tools/montreal-forced-aligner/bin'
os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH'] os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
def _get_max_idx(dic): def _get_max_idx(dic):
return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
@ -111,7 +110,7 @@ def alignment(wav_path: str,
tmpbase = './tmp_dir/' + tmp_name tmpbase = './tmp_dir/' + tmp_name
tmpbase = Path(tmpbase) tmpbase = Path(tmpbase)
tmpbase.mkdir(parents=True, exist_ok=True) tmpbase.mkdir(parents=True, exist_ok=True)
print("tmp_name in alignment:", tmp_name) print("tmp_name in alignment:",tmp_name)
shutil.copyfile(wav_path, tmpbase / wav_name) shutil.copyfile(wav_path, tmpbase / wav_name)
txt_name = utt + '.txt' txt_name = utt + '.txt'
@ -341,7 +340,7 @@ def get_phns_spans(wav_path: str,
if __name__ == '__main__': if __name__ == '__main__':
text = "For that reason cover should not be given." text = "For that reason cover should not be given."
phn, dur, word2phns = alignment("source/p243_313.wav", text, lang='en') phn, dur, word2phns = alignment("exp/p243_313.wav", text, lang='en')
print(phn, dur) print(phn, dur)
print(word2phns) print(word2phns)
print("---------------------------------") print("---------------------------------")
@ -353,7 +352,7 @@ if __name__ == '__main__':
style=pypinyin.Style.TONE3, style=pypinyin.Style.TONE3,
tone_sandhi=True) tone_sandhi=True)
text_zh = " ".join(text_zh) text_zh = " ".join(text_zh)
phn, dur, word2phns = alignment("source/000001.wav", text_zh, lang='zh') phn, dur, word2phns = alignment("exp/000001.wav", text_zh, lang='zh')
print(phn, dur) print(phn, dur)
print(word2phns) print(word2phns)
print("---------------------------------") print("---------------------------------")
@ -368,7 +367,7 @@ if __name__ == '__main__':
print("---------------------------------") print("---------------------------------")
outs = get_phns_spans( outs = get_phns_spans(
wav_path="source/p243_313.wav", wav_path="exp/p243_313.wav",
old_str="For that reason cover should not be given.", old_str="For that reason cover should not be given.",
new_str="for that reason cover is impossible to be given.") new_str="for that reason cover is impossible to be given.")

@ -11,41 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import os
from pathlib import Path
from typing import List
import librosa import librosa
import numpy as np import numpy as np
import paddle
import pypinyin
import soundfile as sf import soundfile as sf
import yaml
from pypinyin_dict.phrase_pinyin_data import large_pinyin
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.syn_utils import norm from paddlespeech.t2s.exps.syn_utils import norm
from paddlespeech.t2s.utils import str2bool from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
large_pinyin.load()
def _p2id(phonemes: List[str]) -> np.ndarray:
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp # replace unk phone with sp
phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes] phonemes = [
phn if phn in vocab_phones else "sp" for phn in phonemes
]
phone_ids = [vocab_phones[item] for item in phonemes] phone_ids = [vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64) return np.array(phone_ids, np.int64)
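# _p2id maps phoneme strings to integer ids via the module-level vocab_phones dict
# (built from the phones_dict file in __main__), substituting "sp" for any out-of-vocabulary phone.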
def prep_feats_with_dur(wav_path: str, def prep_feats_with_dur(wav_path: str,
old_str: str='', old_str: str='',
new_str: str='', new_str: str='',
@ -73,12 +67,12 @@ def prep_feats_with_dur(wav_path: str,
fs=fs, fs=fs,
n_shift=n_shift) n_shift=n_shift)
mfa_start = phns_spans_outs['mfa_start'] mfa_start = phns_spans_outs["mfa_start"]
mfa_end = phns_spans_outs['mfa_end'] mfa_end = phns_spans_outs["mfa_end"]
old_phns = phns_spans_outs['old_phns'] old_phns = phns_spans_outs["old_phns"]
new_phns = phns_spans_outs['new_phns'] new_phns = phns_spans_outs["new_phns"]
span_to_repl = phns_spans_outs['span_to_repl'] span_to_repl = phns_spans_outs["span_to_repl"]
span_to_add = phns_spans_outs['span_to_add'] span_to_add = phns_spans_outs["span_to_add"]
# Chinese phns are not necessarily all in the fastspeech2 dictionary, so use sp instead # Chinese phns are not necessarily all in the fastspeech2 dictionary, so use sp instead
if target_lang in {'en', 'zh'}: if target_lang in {'en', 'zh'}:
@ -138,7 +132,7 @@ def prep_feats_with_dur(wav_path: str,
[wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]]) [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
# the audio is masked as expected # the audio is masked as expected
sf.write(str("mask_wav.wav"), new_wav, samplerate=fs) sf.write(str("new_wav.wav"), new_wav, samplerate=fs)
# 4. get old and new mel span to be mask # 4. get old and new mel span to be mask
old_span_bdy = get_span_bdy( old_span_bdy = get_span_bdy(
@ -158,6 +152,8 @@ def prep_feats_with_dur(wav_path: str,
return outs return outs
def prep_feats(wav_path: str, def prep_feats(wav_path: str,
old_str: str='', old_str: str='',
new_str: str='', new_str: str='',
@ -167,7 +163,7 @@ def prep_feats(wav_path: str,
fs: int=24000, fs: int=24000,
n_shift: int=300): n_shift: int=300):
with_dur_outs = prep_feats_with_dur( outs = prep_feats_with_dur(
wav_path=wav_path, wav_path=wav_path,
old_str=old_str, old_str=old_str,
new_str=new_str, new_str=new_str,
@ -180,14 +176,15 @@ def prep_feats(wav_path: str,
wav_name = os.path.basename(wav_path) wav_name = os.path.basename(wav_path)
utt_id = wav_name.split('.')[0] utt_id = wav_name.split('.')[0]
wav = with_dur_outs['new_wav'] wav = outs['new_wav']
phns = with_dur_outs['new_phns'] phns = outs['new_phns']
mfa_start = with_dur_outs['new_mfa_start'] mfa_start = outs['new_mfa_start']
mfa_end = with_dur_outs['new_mfa_end'] mfa_end = outs['new_mfa_end']
old_span_bdy = with_dur_outs['old_span_bdy'] old_span_bdy = outs['old_span_bdy']
new_span_bdy = with_dur_outs['new_span_bdy'] new_span_bdy = outs['new_span_bdy']
span_bdy = np.array(new_span_bdy) span_bdy = np.array(new_span_bdy)
text = _p2id(phns)
mel = mel_extractor.get_log_mel_fbank(wav) mel = mel_extractor.get_log_mel_fbank(wav)
erniesat_mean, erniesat_std = np.load(erniesat_stat) erniesat_mean, erniesat_std = np.load(erniesat_stat)
normed_mel = norm(mel, erniesat_mean, erniesat_std) normed_mel = norm(mel, erniesat_mean, erniesat_std)
@ -195,225 +192,122 @@ def prep_feats(wav_path: str,
tmpbase = './tmp_dir/' + tmp_name tmpbase = './tmp_dir/' + tmp_name
tmpbase = Path(tmpbase) tmpbase = Path(tmpbase)
tmpbase.mkdir(parents=True, exist_ok=True) tmpbase.mkdir(parents=True, exist_ok=True)
print("tmp_name in synthesize_e2e:",tmp_name)
mel_path = tmpbase / 'mel.npy' mel_path = tmpbase / 'mel.npy'
np.save(mel_path, normed_mel) print("mel_path:",mel_path)
np.save(mel_path, logmel)
durations = [e - s for e, s in zip(mfa_end, mfa_start)] durations = [e - s for e, s in zip(mfa_end, mfa_start)]
text = _p2id(phns)
datum = { datum={
"utt_id": utt_id, "utt_id": utt_id,
"spk_id": 0, "spk_id": 0,
"text": text, "text": text,
"text_lengths": len(text), "text_lengths": len(text),
"speech_lengths": len(normed_mel), "speech_lengths": 115,
"durations": durations, "durations": durations,
"speech": np.load(mel_path), "speech": mel_path,
"align_start": mfa_start, "align_start": mfa_start,
"align_end": mfa_end, "align_end": mfa_end,
"span_bdy": span_bdy "span_bdy": span_bdy
} }
batch = collate_fn([datum]) batch = collate_fn([datum])
outs = dict() print("batch:",batch)
outs['batch'] = batch
outs['old_span_bdy'] = old_span_bdy return batch, old_span_bdy, new_span_bdy
outs['new_span_bdy'] = new_span_bdy
return outs
def get_mlm_output(wav_path: str, def decode_with_model(mlm_model: nn.Layer,
collate_fn,
wav_path: str,
old_str: str='', old_str: str='',
new_str: str='', new_str: str='',
source_lang: str='en', source_lang: str='en',
target_lang: str='en', target_lang: str='en',
use_teacher_forcing: bool=False,
duration_adjust: bool=True, duration_adjust: bool=True,
fs: int=24000, fs: int=24000,
n_shift: int=300): n_shift: int=300,
token_list: List[str]=[]):
prep_feats_outs = prep_feats( batch, old_span_bdy, new_span_bdy = prep_feats(
source_lang=source_lang,
target_lang=target_lang,
wav_path=wav_path, wav_path=wav_path,
old_str=old_str, old_str=old_str,
new_str=new_str, new_str=new_str,
source_lang=source_lang,
target_lang=target_lang,
duration_adjust=duration_adjust, duration_adjust=duration_adjust,
fs=fs, fs=fs,
n_shift=n_shift) n_shift=n_shift,
token_list=token_list)
batch = prep_feats_outs['batch'] feats = collate_fn(batch)[1]
new_span_bdy = prep_feats_outs['new_span_bdy']
old_span_bdy = prep_feats_outs['old_span_bdy']
out_mels = erniesat_inference( if 'text_masked_pos' in feats.keys():
speech=batch['speech'], feats.pop('text_masked_pos')
text=batch['text'],
masked_pos=batch['masked_pos'], output = mlm_model.inference(
speech_mask=batch['speech_mask'], text=feats['text'],
text_mask=batch['text_mask'], speech=feats['speech'],
speech_seg_pos=batch['speech_seg_pos'], masked_pos=feats['masked_pos'],
text_seg_pos=batch['text_seg_pos'], speech_mask=feats['speech_mask'],
span_bdy=new_span_bdy) text_mask=feats['text_mask'],
speech_seg_pos=feats['speech_seg_pos'],
text_seg_pos=feats['text_seg_pos'],
span_bdy=new_span_bdy,
use_teacher_forcing=use_teacher_forcing)
# concatenate the audio # concatenate the audio
output_feat = paddle.concat(x=out_mels, axis=0) output_feat = paddle.concat(x=output, axis=0)
wav_org, _ = librosa.load(wav_path, sr=fs) wav_org, _ = librosa.load(wav_path, sr=fs)
outs = dict() return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length
outs['wav_org'] = wav_org
outs['output_feat'] = output_feat
outs['old_span_bdy'] = old_span_bdy
outs['new_span_bdy'] = new_span_bdy
return outs
def get_wav(wav_path: str, if __name__ == '__main__':
source_lang: str='en', fs = 24000
target_lang: str='en', n_shift = 300
old_str: str='', wav_path = "exp/p243_313.wav"
new_str: str='', old_str = "For that reason cover should not be given."
duration_adjust: bool=True, # for edit
fs: int=24000, # new_str = "for that reason cover is impossible to be given."
n_shift: int=300): # for synthesize
append_str = "do you love me i love you so much"
new_str = old_str + append_str
outs = get_mlm_output( '''
outs = prep_feats_with_dur(
wav_path=wav_path, wav_path=wav_path,
old_str=old_str, old_str=old_str,
new_str=new_str, new_str=new_str,
source_lang=source_lang,
target_lang=target_lang,
duration_adjust=duration_adjust,
fs=fs, fs=fs,
n_shift=n_shift) n_shift=n_shift)
wav_org = outs['wav_org'] new_wav = outs['new_wav']
output_feat = outs['output_feat'] new_phns = outs['new_phns']
new_mfa_start = outs['new_mfa_start']
new_mfa_end = outs['new_mfa_end']
old_span_bdy = outs['old_span_bdy'] old_span_bdy = outs['old_span_bdy']
new_span_bdy = outs['new_span_bdy'] new_span_bdy = outs['new_span_bdy']
masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] print("---------------------------------")
with paddle.no_grad():
alt_wav = voc_inference(masked_feat)
alt_wav = np.squeeze(alt_wav)
old_time_bdy = [n_shift * x for x in old_span_bdy]
wav_replaced = np.concatenate(
[wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])
wav_dict = {"origin": wav_org, "output": wav_replaced}
return wav_dict
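# Note: get_wav() vocodes only the regenerated mel span (new_span_bdy) and splices the
# resulting audio back into the original waveform; old_span_bdy is converted from frames
# to samples by multiplying with n_shift.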
def parse_args():
# parse args and config
parser = argparse.ArgumentParser(
description="Synthesize with acoustic model & vocoder")
# ernie sat
parser.add_argument(
'--erniesat_config',
type=str,
default=None,
help='Config of acoustic model.')
parser.add_argument(
'--erniesat_ckpt',
type=str,
default=None,
help='Checkpoint file of acoustic model.')
parser.add_argument(
"--erniesat_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training acoustic model."
)
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
# vocoder
parser.add_argument(
'--voc',
type=str,
default='pwgan_csmsc',
choices=[
'pwgan_aishell3',
'pwgan_vctk',
'hifigan_aishell3',
'hifigan_vctk',
],
help='Choose vocoder type of tts task.')
parser.add_argument(
'--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
'--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
"--voc_stat",
type=str,
default=None,
help="mean and standard deviation used to normalize spectrogram when training voc."
)
# other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
# ernie sat related
parser.add_argument("--task_name", type=str, help="task name")
parser.add_argument("--wav_path", type=str, help="path of old wav")
parser.add_argument("--old_str", type=str, help="old string")
parser.add_argument("--new_str", type=str, help="new string")
parser.add_argument(
"--source_lang", type=str, default="en", help="source language")
parser.add_argument(
"--target_lang", type=str, default="en", help="target language")
parser.add_argument(
"--duration_adjust",
type=str2bool,
default=True,
help="whether to adjust duration.")
parser.add_argument("--output_name", type=str, default="output.wav")
args = parser.parse_args()
return args
print("new_wav:", new_wav)
print("new_phns:", new_phns)
print("new_mfa_start:", new_mfa_start)
print("new_mfa_end:", new_mfa_end)
print("old_span_bdy:", old_span_bdy)
print("new_span_bdy:", new_span_bdy)
print("---------------------------------")
'''
if __name__ == '__main__': erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"
args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should be >= 0 !")
# evaluate(args) with open(erniesat_config) as f:
with open(args.erniesat_config) as f:
erniesat_config = CfgNode(yaml.safe_load(f)) erniesat_config = CfgNode(yaml.safe_load(f))
old_str = args.old_str
new_str = args.new_str erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"
# convert Chinese characters to pinyin
if args.source_lang == 'zh':
old_str = pypinyin.lazy_pinyin(
old_str,
neutral_tone_with_five=True,
style=pypinyin.Style.TONE3,
tone_sandhi=True)
old_str = ' '.join(old_str)
if args.target_lang == 'zh':
new_str = pypinyin.lazy_pinyin(
new_str,
neutral_tone_with_five=True,
style=pypinyin.Style.TONE3,
tone_sandhi=True)
new_str = ' '.join(new_str)
if args.task_name == 'edit':
new_str = new_str
elif args.task_name == 'synthesize':
new_str = old_str + new_str
else:
new_str = old_str + new_str
print("new_str:", new_str)
# Extractor # Extractor
mel_extractor = LogMelFBank( mel_extractor = LogMelFBank(
@ -426,50 +320,27 @@ if __name__ == '__main__':
fmin=erniesat_config.fmin, fmin=erniesat_config.fmin,
fmax=erniesat_config.fmax) fmax=erniesat_config.fmax)
collate_fn = build_erniesat_collate_fn( collate_fn = build_erniesat_collate_fn(
mlm_prob=erniesat_config.mlm_prob, mlm_prob=erniesat_config.mlm_prob,
mean_phn_span=erniesat_config.mean_phn_span, mean_phn_span=erniesat_config.mean_phn_span,
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
text_masking=False) text_masking=False)
phones_dict='/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
vocab_phones = {} vocab_phones = {}
with open(args.phones_dict, 'rt') as f: with open(phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id: for phn, id in phn_id:
vocab_phones[phn] = int(id) vocab_phones[phn] = int(id)
# ernie sat model prep_feats(wav_path=wav_path,
erniesat_inference = get_am_inference(
am='erniesat_dataset',
am_config=erniesat_config,
am_ckpt=args.erniesat_ckpt,
am_stat=args.erniesat_stat,
phones_dict=args.phones_dict)
with open(args.voc_config) as f:
voc_config = CfgNode(yaml.safe_load(f))
# vocoder
voc_inference = get_voc_inference(
voc=args.voc,
voc_config=voc_config,
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
erniesat_stat = args.erniesat_stat
wav_dict = get_wav(
wav_path=args.wav_path,
source_lang=args.source_lang,
target_lang=args.target_lang,
old_str=old_str, old_str=old_str,
new_str=new_str, new_str=new_str,
duration_adjust=args.duration_adjust, fs=fs,
fs=erniesat_config.fs, n_shift=n_shift)
n_shift=erniesat_config.n_shift)
sf.write(
args.output_name, wav_dict['output'], samplerate=erniesat_config.fs)
print(
f"\033[1;32;m Generated audio saved into {args.output_name} ! \033[0m")

@ -25,6 +25,7 @@ from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle import nn from paddle import nn
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam from paddle.optimizer import Adam
from yacs.config import CfgNode from yacs.config import CfgNode

@ -11,35 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import hashlib
import os
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Union from typing import Union
import os
import numpy as np import numpy as np
import paddle import paddle
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
import hashlib
from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference
def _get_user(): def _get_user():
return os.path.expanduser('~').split('/')[-1] return os.path.expanduser('~').split('/')[-1]
def str2md5(string): def str2md5(string):
md5_val = hashlib.md5(string.encode('utf8')).hexdigest() md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
return md5_val return md5_val
def get_tmp_name(text:str):
def get_tmp_name(text: str):
return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text) return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)
def get_dict(dictfile: str): def get_dict(dictfile: str):
word2phns_dict = {} word2phns_dict = {}
with open(dictfile, 'r') as fid: with open(dictfile, 'r') as fid:

@ -82,10 +82,6 @@ def denorm(data, mean, std):
return data * std + mean return data * std + mean
def norm(data, mean, std):
return (data - mean) / std
def get_chunks(data, block_size: int, pad_size: int): def get_chunks(data, block_size: int, pad_size: int):
data_len = data.shape[1] data_len = data.shape[1]
chunks = [] chunks = []
@ -298,8 +294,8 @@ def am_to_static(am_inference,
am_name = am[:am.rindex('_')] am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:] am_dataset = am[am.rindex('_') + 1:]
if am_name == 'fastspeech2': if am_name == 'fastspeech2':
if am_dataset in {"aishell3", "vctk", if am_dataset in {"aishell3", "vctk", "mix"
"mix"} and speaker_dict is not None: } and speaker_dict is not None:
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, am_inference,
input_spec=[ input_spec=[
@ -311,8 +307,8 @@ def am_to_static(am_inference,
am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
elif am_name == 'speedyspeech': elif am_name == 'speedyspeech':
if am_dataset in {"aishell3", "vctk", if am_dataset in {"aishell3", "vctk", "mix"
"mix"} and speaker_dict is not None: } and speaker_dict is not None:
am_inference = jit.to_static( am_inference = jit.to_static(
am_inference, am_inference,
input_spec=[ input_spec=[

@ -15,7 +15,6 @@ import argparse
from pathlib import Path from pathlib import Path
import jsonlines import jsonlines
import numpy as np
import paddle import paddle
import soundfile as sf import soundfile as sf
import yaml import yaml
@ -24,7 +23,6 @@ from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.utils import str2bool
def evaluate(args): def evaluate(args):
@ -42,26 +40,8 @@ def evaluate(args):
print(config) print(config)
fields = ["utt_id", "text"] fields = ["utt_id", "text"]
converters = {}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Evaluating voice cloning!")
fields += ["spk_emb"]
else:
print("single speaker vits!")
print("spk_num:", spk_num)
test_dataset = DataTable( test_dataset = DataTable(data=test_metadata, fields=fields)
data=test_metadata,
fields=fields,
converters=converters, )
with open(args.phones_dict, "r") as f: with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
@ -69,7 +49,6 @@ def evaluate(args):
print("vocab_size:", vocab_size) print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1 odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
vits = VITS(idim=vocab_size, odim=odim, **config["model"]) vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
@ -86,15 +65,7 @@ def evaluate(args):
phone_ids = paddle.to_tensor(datum["text"]) phone_ids = paddle.to_tensor(datum["text"])
with timer() as t: with timer() as t:
with paddle.no_grad(): with paddle.no_grad():
spk_emb = None out = vits.inference(text=phone_ids)
spk_id = None
# multi speaker
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
elif "spk_id" in datum:
spk_id = paddle.to_tensor(datum["spk_id"])
out = vits.inference(
text=phone_ids, sids=spk_id, spembs=spk_emb)
wav = out["wav"] wav = out["wav"]
wav = wav.numpy() wav = wav.numpy()
N += wav.size N += wav.size
@ -119,13 +90,6 @@ def parse_args():
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.') '--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument( parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.") "--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
# other # other
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")

@ -42,23 +42,12 @@ def evaluate(args):
# frontend # frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
else:
print("single speaker vits!")
print("spk_num:", spk_num)
with open(args.phones_dict, "r") as f: with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()] phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id) vocab_size = len(phn_id)
print("vocab_size:", vocab_size) print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1 odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
vits = VITS(idim=vocab_size, odim=odim, **config["model"]) vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
@ -89,10 +78,7 @@ def evaluate(args):
flags = 0 flags = 0
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i] part_phone_ids = phone_ids[i]
spk_id = None out = vits.inference(text=part_phone_ids)
if spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
out = vits.inference(text=part_phone_ids, sids=spk_id)
wav = out["wav"] wav = out["wav"]
if flags == 0: if flags == 0:
wav_all = wav wav_all = wav
@ -123,13 +109,6 @@ def parse_args():
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.') '--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument( parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.") "--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# other # other
parser.add_argument( parser.add_argument(
'--lang', '--lang',

@ -28,7 +28,6 @@ from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam from paddle.optimizer import Adam
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS from paddlespeech.t2s.models.vits import VITS
@ -44,7 +43,6 @@ from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import scheduler_classes from paddlespeech.t2s.training.optimizer import scheduler_classes
from paddlespeech.t2s.training.seeding import seed_everything from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer from paddlespeech.t2s.training.trainer import Trainer
from paddlespeech.t2s.utils import str2bool
def train_sp(args, config): def train_sp(args, config):
@ -74,23 +72,6 @@ def train_sp(args, config):
"wave": np.load, "wave": np.load,
"feats": np.load, "feats": np.load,
} }
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
collate_fn = vits_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Training voice cloning!")
collate_fn = vits_multi_spk_batch_fn
fields += ["spk_emb"]
converters["spk_emb"] = np.load
else:
print("single speaker vits!")
collate_fn = vits_single_spk_batch_fn
print("spk_num:", spk_num)
# construct dataset for training and validation # construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader: with jsonlines.open(args.train_metadata, 'r') as reader:
@ -119,16 +100,18 @@ def train_sp(args, config):
drop_last=False) drop_last=False)
print("samplers done!") print("samplers done!")
train_batch_fn = vits_single_spk_batch_fn
train_dataloader = DataLoader( train_dataloader = DataLoader(
train_dataset, train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
collate_fn=collate_fn, collate_fn=train_batch_fn,
num_workers=config.num_workers) num_workers=config.num_workers)
dev_dataloader = DataLoader( dev_dataloader = DataLoader(
dev_dataset, dev_dataset,
batch_sampler=dev_sampler, batch_sampler=dev_sampler,
collate_fn=collate_fn, collate_fn=train_batch_fn,
num_workers=config.num_workers) num_workers=config.num_workers)
print("dataloaders done!") print("dataloaders done!")
@ -138,7 +121,6 @@ def train_sp(args, config):
print("vocab_size:", vocab_size) print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1 odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
model = VITS(idim=vocab_size, odim=odim, **config["model"]) model = VITS(idim=vocab_size, odim=odim, **config["model"])
gen_parameters = model.generator.parameters() gen_parameters = model.generator.parameters()
dis_parameters = model.discriminator.parameters() dis_parameters = model.discriminator.parameters()
@ -258,17 +240,6 @@ def main():
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument( parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.") "--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
args = parser.parse_args() args = parser.parse_args()

@ -21,28 +21,13 @@ import soundfile as sf
import yaml import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.cli.vector import VectorExecutor
from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.utils import str2bool
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def gen_random_embed(use_ecapa: bool=False):
if use_ecapa:
# Randomly generate numbers of -25 ~ 25, 192 is the dim of spk_emb
random_spk_emb = (-1 + 2 * np.random.rand(192)) * 25
# GE2E
else:
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
random_spk_emb = np.random.rand(256) * 0.2
random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32')
return random_spk_emb
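# Illustrative usage: gen_random_embed(use_ecapa=True) returns a random 192-dim embedding
# for ECAPA-TDNN, gen_random_embed(False) a random 256-dim one for GE2E.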
def voice_cloning(args): def voice_cloning(args):
# Init body. # Init body.
with open(args.am_config) as f: with open(args.am_config) as f:
@ -56,20 +41,7 @@ def voice_cloning(args):
print(am_config) print(am_config)
print(voc_config) print(voc_config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
# speaker encoder # speaker encoder
if args.use_ecapa:
vec_executor = VectorExecutor()
# warm up
vec_executor(
audio_file=input_dir / os.listdir(input_dir)[0], force_yes=True)
print("ECAPA-TDNN Done!")
# use GE2E
else:
p = SpeakerVerificationPreprocessor( p = SpeakerVerificationPreprocessor(
sampling_rate=16000, sampling_rate=16000,
audio_norm_target_dBFS=-30, audio_norm_target_dBFS=-30,
@ -93,10 +65,6 @@ def voice_cloning(args):
frontend = Frontend(phone_vocab_path=args.phones_dict) frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!") print("frontend done!")
sentence = args.text
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"][0]
# acoustic model # acoustic model
am_inference = get_am_inference( am_inference = get_am_inference(
am=args.am, am=args.am,
@ -112,19 +80,26 @@ def voice_cloning(args):
voc_ckpt=args.voc_ckpt, voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat) voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
sentence = args.text
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"][0]
for name in os.listdir(input_dir): for name in os.listdir(input_dir):
utt_id = name.split(".")[0] utt_id = name.split(".")[0]
ref_audio_path = input_dir / name ref_audio_path = input_dir / name
if args.use_ecapa: mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
spk_emb = vec_executor(audio_file=ref_audio_path, force_yes=True) # print("mel_sequences: ", mel_sequences.shape)
spk_emb = paddle.to_tensor(spk_emb)
# GE2E
else:
mel_sequences = p.extract_mel_partials(
p.preprocess_wav(ref_audio_path))
with paddle.no_grad(): with paddle.no_grad():
spk_emb = speaker_encoder.embed_utterance( spk_emb = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences)) paddle.to_tensor(mel_sequences))
# print("spk_emb shape: ", spk_emb.shape)
with paddle.no_grad(): with paddle.no_grad():
wav = voc_inference(am_inference(phone_ids, spk_emb=spk_emb)) wav = voc_inference(am_inference(phone_ids, spk_emb=spk_emb))
@ -133,15 +108,14 @@ def voice_cloning(args):
wav.numpy(), wav.numpy(),
samplerate=am_config.fs) samplerate=am_config.fs)
print(f"{utt_id} done!") print(f"{utt_id} done!")
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
# generate 5 random_spk_emb random_spk_emb = np.random.rand(256) * 0.2
for i in range(5): random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32')
random_spk_emb = gen_random_embed(args.use_ecapa)
utt_id = "random_spk_emb" utt_id = "random_spk_emb"
with paddle.no_grad(): with paddle.no_grad():
wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb)) wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb))
sf.write( sf.write(
str(output_dir / (utt_id + "_" + str(i) + ".wav")), str(output_dir / (utt_id + ".wav")),
wav.numpy(), wav.numpy(),
samplerate=am_config.fs) samplerate=am_config.fs)
print(f"{utt_id} done!") print(f"{utt_id} done!")
@ -197,15 +171,13 @@ def parse_args():
type=str, type=str,
default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。", default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
help="text to synthesize, a line") help="text to synthesize, a line")
parser.add_argument( parser.add_argument(
"--ge2e_params_path", type=str, help="ge2e params path.") "--ge2e_params_path", type=str, help="ge2e params path.")
parser.add_argument(
"--use_ecapa",
type=str2bool,
default=False,
help="whether to use ECAPA-TDNN as speaker encoder.")
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument( parser.add_argument(
"--input-dir", "--input-dir",
type=str, type=str,

@ -1 +1,2 @@
from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter

@ -81,12 +81,12 @@ def prepare_onnx_input(tokenizer,
position_ids.append(position_id) position_ids.append(position_id)
outputs = { outputs = {
'input_ids': np.array(input_ids).astype(np.int64), 'input_ids': np.array(input_ids),
'token_type_ids': np.array(token_type_ids).astype(np.int64), 'token_type_ids': np.array(token_type_ids),
'attention_masks': np.array(attention_masks).astype(np.int64), 'attention_masks': np.array(attention_masks),
'phoneme_masks': np.array(phoneme_masks).astype(np.float32), 'phoneme_masks': np.array(phoneme_masks).astype(np.float32),
'char_ids': np.array(char_ids).astype(np.int64), 'char_ids': np.array(char_ids),
'position_ids': np.array(position_ids).astype(np.int64), 'position_ids': np.array(position_ids),
} }
return outputs return outputs

@ -34,7 +34,7 @@ from paddlespeech.t2s.frontend.g2pw.utils import load_config
from paddlespeech.t2s.frontend.zh_normalization.char_convert import tranditional_to_simplified from paddlespeech.t2s.frontend.zh_normalization.char_convert import tranditional_to_simplified
from paddlespeech.utils.env import MODEL_HOME from paddlespeech.utils.env import MODEL_HOME
model_version = '1.1' model_version = '1.0'
def predict(session, onnx_input, labels): def predict(session, onnx_input, labels):

@ -61,11 +61,7 @@ class MixFrontend():
return False return False
def is_end(self, before_char, after_char) -> bool: def is_end(self, before_char, after_char) -> bool:
flag = 0 if ((self.is_alphabet(before_char) or before_char == " ") and (self.is_alphabet(after_char) or after_char == " ")):
for char in (before_char, after_char):
if self.is_alphabet(char) or char == " ":
flag += 1
if flag == 2:
return True return True
else: else:
return False return False
@ -90,11 +86,10 @@ class MixFrontend():
if point_index == 0 or point_index == len(text) - 1: if point_index == 0 or point_index == len(text) - 1:
new_text = text new_text = text
else: else:
if not self.is_end(text[point_index - 1], text[point_index + if not self.is_end(text[point_index - 1], text[point_index + 1]):
1]):
new_text = text new_text = text
else: else:
new_text = text[:point_index] + "" + text[point_index + 1:] new_text = text[: point_index] + "" + text[point_index + 1:]
elif len(point_indexs) == 2: elif len(point_indexs) == 2:
first_index = point_indexs[0] first_index = point_indexs[0]
@ -102,8 +97,7 @@ class MixFrontend():
# first # first
if first_index != 0: if first_index != 0:
if not self.is_end(text[first_index - 1], text[first_index + if not self.is_end(text[first_index - 1], text[first_index + 1]):
1]):
new_text += (text[:first_index] + ".") new_text += (text[:first_index] + ".")
else: else:
new_text += (text[:first_index] + "") new_text += (text[:first_index] + "")
@ -112,10 +106,9 @@ class MixFrontend():
# last # last
if end_index != len(text) - 1: if end_index != len(text) - 1:
if not self.is_end(text[end_index - 1], text[end_index + 1]): if not self.is_end(text[end_index - 1], text[end_index + 1]):
new_text += text[point_indexs[-2] + 1:] new_text += text[point_indexs[-2] + 1 : ]
else: else:
new_text += (text[point_indexs[-2] + 1:end_index] + "" + new_text += (text[point_indexs[-2] + 1 : end_index] + "" + text[end_index + 1 : ])
text[end_index + 1:])
else: else:
new_text += "." new_text += "."
@ -124,8 +117,7 @@ class MixFrontend():
end_index = point_indexs[-1] end_index = point_indexs[-1]
# first # first
if first_index != 0: if first_index != 0:
if not self.is_end(text[first_index - 1], text[first_index + if not self.is_end(text[first_index - 1], text[first_index + 1]):
1]):
new_text += (text[:first_index] + ".") new_text += (text[:first_index] + ".")
else: else:
new_text += (text[:first_index] + "") new_text += (text[:first_index] + "")
@ -134,20 +126,16 @@ class MixFrontend():
# middle # middle
for j in range(1, len(point_indexs) - 1): for j in range(1, len(point_indexs) - 1):
point_index = point_indexs[j] point_index = point_indexs[j]
if not self.is_end(text[point_index - 1], text[point_index + if not self.is_end(text[point_index - 1], text[point_index + 1]):
1]): new_text += (text[point_indexs[j-1] + 1 : point_index] + ".")
new_text += (
text[point_indexs[j - 1] + 1:point_index] + ".")
else: else:
new_text += ( new_text += (text[point_indexs[j-1] + 1 : point_index] + "")
text[point_indexs[j - 1] + 1:point_index] + "")
# last # last
if end_index != len(text) - 1: if end_index != len(text) - 1:
if not self.is_end(text[end_index - 1], text[end_index + 1]): if not self.is_end(text[end_index - 1], text[end_index + 1]):
new_text += text[point_indexs[-2] + 1:] new_text += text[point_indexs[-2] + 1 : ]
else: else:
new_text += (text[point_indexs[-2] + 1:end_index] + "" + new_text += (text[point_indexs[-2] + 1 : end_index] + "" + text[end_index + 1 : ])
text[end_index + 1:])
else: else:
new_text += "." new_text += "."
@ -236,7 +224,7 @@ class MixFrontend():
def get_input_ids(self, def get_input_ids(self,
sentence: str, sentence: str,
merge_sentences: bool=False, merge_sentences: bool=True,
get_tone_ids: bool=False, get_tone_ids: bool=False,
add_sp: bool=True, add_sp: bool=True,
to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]: to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
@ -244,29 +232,28 @@ class MixFrontend():
sentences = self._split(sentence) sentences = self._split(sentence)
phones_list = [] phones_list = []
result = {} result = {}
for text in sentences: for text in sentences:
phones_seg = [] phones_seg = []
segments = self._distinguish(text) segments = self._distinguish(text)
for seg in segments: for seg in segments:
content = seg[0] content = seg[0]
lang = seg[1] lang = seg[1]
if content != '': if lang == "zh":
if lang == "en":
input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=True, to_tensor=to_tensor)
else:
input_ids = self.zh_frontend.get_input_ids( input_ids = self.zh_frontend.get_input_ids(
content, content,
merge_sentences=True, merge_sentences=True,
get_tone_ids=get_tone_ids, get_tone_ids=get_tone_ids,
to_tensor=to_tensor) to_tensor=to_tensor)
elif lang == "en":
input_ids = self.en_frontend.get_input_ids(
content, merge_sentences=True, to_tensor=to_tensor)
phones_seg.append(input_ids["phone_ids"][0]) phones_seg.append(input_ids["phone_ids"][0])
if add_sp: if add_sp:
phones_seg.append(self.sp_id_tensor) phones_seg.append(self.sp_id_tensor)
if phones_seg == []:
phones_seg.append(self.sp_id_tensor)
phones = paddle.concat(phones_seg) phones = paddle.concat(phones_seg)
phones_list.append(phones) phones_list.append(phones)

@ -42,8 +42,3 @@ polyphonic:
咖喱: ['ga1','li5'] 咖喱: ['ga1','li5']
时分: ['shi2','fen1'] 时分: ['shi2','fen1']
蚌埠: ['beng4','bu4'] 蚌埠: ['beng4','bu4']
驯服: ['xun4','fu2']
幸免于难: ['xing4','mian3','yu2','nan4']
恶行: ['e4','xing2']
: ['ai4']

@ -42,7 +42,7 @@ class ToneSandhi():
'木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾', '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
'收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼', '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
'抬举', '护士', '折腾', '扫帚', '打量', '打算', '打扮', '打听', '打发', '扎实', '扁担', '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打扮', '打听', '打发', '扎实', '扁担',
'戒指', '懒得', '意识', '意思', '悟性', '怪物', '思量', '怎么', '念头', '念叨', '别人', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头', '念叨',
'快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', '干事', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼', '干事',
'帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', '屁股', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数', '屁股',
'尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', '实在', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气', '实在',
@ -60,7 +60,7 @@ class ToneSandhi():
'邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', '幸福', '熟悉', '计划', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅', '幸福', '熟悉', '计划',
'扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', '凤凰', '拖沓', '寒碜', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱', '凤凰', '拖沓', '寒碜',
'糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', '扫把', '惦记', '戏弄', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱', '扫把', '惦记', '戏弄',
'将军' '将军', '别人'
} }
self.must_not_neural_tone_words = { self.must_not_neural_tone_words = {
'男子', '女子', '分子', '原子', '量子', '莲子', '石子', '瓜子', '电子', '人人', '虎虎', '男子', '女子', '分子', '原子', '量子', '莲子', '石子', '瓜子', '电子', '人人', '虎虎',
@ -84,7 +84,7 @@ class ToneSandhi():
if j - 1 >= 0 and item == word[j - 1] and pos[0] in {"n", "v", "a"}: if j - 1 >= 0 and item == word[j - 1] and pos[0] in {"n", "v", "a"}:
finals[j] = finals[j][:-1] + "5" finals[j] = finals[j][:-1] + "5"
ge_idx = word.find("个") if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒滴哩哟喽啰耶喔诶":
if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒滴哩哟喽啰耶喔诶": if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒滴哩哟喽啰耶喔诶":
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
elif len(word) >= 1 and word[-1] in "的地得": elif len(word) >= 1 and word[-1] in "的地得":
finals[-1] = finals[-1][:-1] + "5" finals[-1] = finals[-1][:-1] + "5"
@ -169,7 +169,6 @@ class ToneSandhi():
return new_word_list return new_word_list
def _three_sandhi(self, word: str, finals: List[str]) -> List[str]: def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
if len(word) == 2 and self._all_tone_three(finals): if len(word) == 2 and self._all_tone_three(finals):
finals[0] = finals[0][:-1] + "2" finals[0] = finals[0][:-1] + "2"
elif len(word) == 3: elif len(word) == 3:
@ -347,7 +346,6 @@ class ToneSandhi():
def modified_tone(self, word: str, pos: str, def modified_tone(self, word: str, pos: str,
finals: List[str]) -> List[str]: finals: List[str]) -> List[str]:
finals = self._bu_sandhi(word, finals) finals = self._bu_sandhi(word, finals)
finals = self._yi_sandhi(word, finals) finals = self._yi_sandhi(word, finals)
finals = self._neural_sandhi(word, pos, finals) finals = self._neural_sandhi(word, pos, finals)

@ -28,7 +28,7 @@ UNITS = OrderedDict({
8: '亿', 8: '亿',
}) })
COM_QUANTIFIERS = '(封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)' COM_QUANTIFIERS = '(所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
# fraction expressions # fraction expressions
RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)') RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')

@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from .ernie_sat import * from .ernie_sat import *
from .ernie_sat_updater import * from .ernie_sat_updater import *
from .mlm import *

@ -389,7 +389,7 @@ class MLM(nn.Layer):
speech_seg_pos: paddle.Tensor, speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor, text_seg_pos: paddle.Tensor,
span_bdy: List[int], span_bdy: List[int],
use_teacher_forcing: bool=True, ) -> List[paddle.Tensor]: use_teacher_forcing: bool=False, ) -> List[paddle.Tensor]:
''' '''
Args: Args:
speech (paddle.Tensor): input speech (1, Tmax, D). speech (paddle.Tensor): input speech (1, Tmax, D).
@ -657,7 +657,7 @@ class ErnieSAT(nn.Layer):
speech_seg_pos: paddle.Tensor, speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor, text_seg_pos: paddle.Tensor,
span_bdy: List[int], span_bdy: List[int],
use_teacher_forcing: bool=True, ) -> Dict[str, paddle.Tensor]: use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
return self.model.inference( return self.model.inference(
speech=speech, speech=speech,
text=text, text=text,

@ -0,0 +1,579 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import Dict
from typing import List
from typing import Optional
import paddle
import yaml
from paddle import nn
from yacs.config import CfgNode
from paddlespeech.t2s.modules.activation import get_activation
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
from paddlespeech.t2s.modules.layer_norm import LayerNorm
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
from paddlespeech.t2s.modules.transformer.repeat import repeat
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
# MLM -> Masked Language Model
class mySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._sub_layers.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
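# mySequential behaves like nn.Sequential except that tuple outputs (e.g. the
# (x, pos_emb) pair returned by the positional-encoding layers) are unpacked as
# positional arguments to the next sub-layer.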
class MaskInputLayer(nn.Layer):
def __init__(self, out_features: int) -> None:
super().__init__()
self.mask_feature = paddle.create_parameter(
shape=(1, 1, out_features),
dtype=paddle.float32,
default_initializer=paddle.nn.initializer.Assign(
paddle.normal(shape=(1, 1, out_features))))
def forward(self, input: paddle.Tensor,
masked_pos: paddle.Tensor=None) -> paddle.Tensor:
masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
return masked_input
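# MaskInputLayer zeroes the frames selected by masked_pos and writes a single learned
# mask embedding (self.mask_feature, broadcast over batch and time) in their place;
# unmasked frames pass through unchanged.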
class MLMEncoder(nn.Layer):
"""Conformer encoder module.
Args:
idim (int): Input dimension.
attention_dim (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
attention_dropout_rate (float): Dropout rate in attention.
input_layer (Union[str, paddle.nn.Layer]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
macaron_style (bool): Whether to use macaron style for positionwise layer.
pos_enc_layer_type (str): Encoder positional encoding layer type.
selfattention_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
"""
def __init__(self,
idim: int,
vocab_size: int=0,
pre_speech_layer: int=0,
attention_dim: int=256,
attention_heads: int=4,
linear_units: int=2048,
num_blocks: int=6,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.1,
attention_dropout_rate: float=0.0,
input_layer: str="conv2d",
normalize_before: bool=True,
concat_after: bool=False,
positionwise_layer_type: str="linear",
positionwise_conv_kernel_size: int=1,
macaron_style: bool=False,
pos_enc_layer_type: str="abs_pos",
selfattention_layer_type: str="selfattn",
activation_type: str="swish",
use_cnn_module: bool=False,
zero_triu: bool=False,
cnn_module_kernel: int=31,
padding_idx: int=-1,
stochastic_depth_rate: float=0.0,
text_masking: bool=False):
"""Construct an Encoder object."""
super().__init__()
self._output_size = attention_dim
self.text_masking = text_masking
if self.text_masking:
self.text_masking_layer = MaskInputLayer(attention_dim)
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "legacy_rel_pos":
pos_enc_class = LegacyRelPositionalEncoding
assert selfattention_layer_type == "legacy_rel_selfattn"
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
self.conv_subsampling_factor = 1
if input_layer == "linear":
self.embed = nn.Sequential(
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.Dropout(dropout_rate),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate), )
self.conv_subsampling_factor = 4
elif input_layer == "embed":
self.embed = nn.Sequential(
nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "mlm":
self.segment_emb = None
self.speech_embed = mySequential(
MaskInputLayer(idim),
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate))
self.text_embed = nn.Sequential(
nn.Embedding(
vocab_size, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer == "sega_mlm":
self.segment_emb = nn.Embedding(
500, attention_dim, padding_idx=padding_idx)
self.speech_embed = mySequential(
MaskInputLayer(idim),
nn.Linear(idim, attention_dim),
nn.LayerNorm(attention_dim),
nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate))
self.text_embed = nn.Sequential(
nn.Embedding(
vocab_size, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate), )
elif isinstance(input_layer, nn.Layer):
self.embed = nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate), )
elif input_layer is None:
self.embed = nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate))
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
# self-attention module definition
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
elif selfattention_layer_type == "legacy_rel_selfattn":
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, )
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (attention_heads, attention_dim,
attention_dropout_rate, zero_triu, )
else:
raise ValueError("unknown encoder_attn_layer: " +
selfattention_layer_type)
# feed-forward module definition
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units,
dropout_rate, activation, )
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (attention_dim, linear_units,
positionwise_conv_kernel_size,
dropout_rate, )
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (attention_dim, linear_units,
positionwise_conv_kernel_size,
dropout_rate, )
else:
raise NotImplementedError("Support only linear or conv1d.")
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
self.pre_speech_layer = pre_speech_layer
self.pre_speech_encoders = repeat(
self.pre_speech_layer,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
def forward(self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor=None,
text_mask: paddle.Tensor=None,
speech_seg_pos: paddle.Tensor=None,
text_seg_pos: paddle.Tensor=None):
"""Encode input sequence.
"""
if masked_pos is not None:
speech = self.speech_embed(speech, masked_pos)
else:
speech = self.speech_embed(speech)
if text is not None:
text = self.text_embed(text)
if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
speech_seg_emb = self.segment_emb(speech_seg_pos)
text_seg_emb = self.segment_emb(text_seg_pos)
text = (text[0] + text_seg_emb, text[1])
speech = (speech[0] + speech_seg_emb, speech[1])
if self.pre_speech_encoders:
speech, _ = self.pre_speech_encoders(speech, speech_mask)
if text is not None:
xs = paddle.concat([speech[0], text[0]], axis=1)
xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
masks = paddle.concat([speech_mask, text_mask], axis=-1)
else:
xs = speech[0]
xs_pos_emb = speech[1]
masks = speech_mask
xs, masks = self.encoders((xs, xs_pos_emb), masks)
if isinstance(xs, tuple):
xs = xs[0]
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
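
    # --- Shape sketch (added for clarity; not part of the original file) ----
    # Assuming batch size B, T speech frames and T2 text tokens (shapes follow
    # the MLM.inference docstring further below):
    #   speech:      (B, T, idim)  -> speech_embed -> ((B, T, D), pos_emb)
    #   text:        (B, T2)       -> text_embed   -> ((B, T2, D), pos_emb)
    #   speech_mask: (B, 1, T),  text_mask: (B, 1, T2)
    # The two streams are concatenated along the time axis, so the encoder
    # stack above receives xs of shape (B, T + T2, D) and masks of shape
    # (B, 1, T + T2); the returned xs keeps that concatenated length.
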
class MLMDecoder(MLMEncoder):
def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
"""Encode input sequence.
Args:
xs (paddle.Tensor): Input tensor (#batch, time, idim).
masks (paddle.Tensor): Mask tensor (#batch, time).
Returns:
paddle.Tensor: Output tensor (#batch, time, attention_dim).
paddle.Tensor: Mask tensor (#batch, time).
"""
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if isinstance(xs, tuple):
xs = xs[0]
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
# encoder and decoder are nn.Layer instances here, not strings
class MLM(nn.Layer):
def __init__(self,
odim: int,
encoder: nn.Layer,
decoder: Optional[nn.Layer],
postnet_layers: int=0,
postnet_chans: int=0,
postnet_filts: int=0,
text_masking: bool=False):
super().__init__()
self.odim = odim
self.encoder = encoder
self.decoder = decoder
self.vocab_size = encoder.text_embed[0]._num_embeddings
if self.decoder is None or not (hasattr(self.decoder,
'output_layer') and
self.decoder.output_layer is not None):
self.sfc = nn.Linear(self.encoder._output_size, odim)
else:
self.sfc = None
if text_masking:
self.text_sfc = nn.Linear(
self.encoder.text_embed[0]._embedding_dim,
self.vocab_size,
weight_attr=self.encoder.text_embed[0]._weight_attr)
else:
self.text_sfc = None
self.postnet = (None if postnet_layers == 0 else Postnet(
idim=self.encoder._output_size,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=True,
dropout_rate=0.5, ))
def inference(
self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor,
span_bdy: List[int],
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
'''
Args:
speech (paddle.Tensor): input speech (1, Tmax, D).
text (paddle.Tensor): input text (1, Tmax2).
masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
            speech_seg_pos (paddle.Tensor): phone index of each mel frame, 0<=n<=Tmax2 (1, Tmax).
            text_seg_pos (paddle.Tensor): phone index of each text token, 0<=n<=Tmax2 (1, Tmax2).
span_bdy (List[int]): masked mel boundary of input speech (2,)
use_teacher_forcing (bool): whether to use teacher forcing
Returns:
List[Tensor]:
                e.g.:
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
'''
z_cache = None
if use_teacher_forcing:
before_outs, zs, *_ = self.forward(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if zs is None:
zs = before_outs
speech = speech.squeeze(0)
outs = [speech[:span_bdy[0]]]
outs += [zs[0][span_bdy[0]:span_bdy[1]]]
outs += [speech[span_bdy[1]:]]
return outs
return None
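
    # --- Hedged usage sketch (added; not part of the original file) ---------
    # With use_teacher_forcing=True, `inference` returns three segments: the
    # untouched left context, the re-predicted masked span, and the untouched
    # right context, which can be concatenated along the time axis to recover
    # a full (Tmax, odim) mel. Every value below (the input tensors, the span
    # boundaries and the `mlm` instance) is a placeholder.
    #
    #     outs = mlm.inference(
    #         speech=speech, text=text, masked_pos=masked_pos,
    #         speech_mask=speech_mask, text_mask=text_mask,
    #         speech_seg_pos=speech_seg_pos, text_seg_pos=text_seg_pos,
    #         span_bdy=[120, 160], use_teacher_forcing=True)
    #     full_mel = paddle.concat(outs, axis=0)  # (Tmax, 80)
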
class MLMEncAsDecoder(MLM):
def forward(self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, h_masks = self.encoder(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, None
class MLMDualMaksing(MLM):
def forward(self,
speech: paddle.Tensor,
text: paddle.Tensor,
masked_pos: paddle.Tensor,
speech_mask: paddle.Tensor,
text_mask: paddle.Tensor,
speech_seg_pos: paddle.Tensor,
text_seg_pos: paddle.Tensor):
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
encoder_out, h_masks = self.encoder(
speech=speech,
text=text,
masked_pos=masked_pos,
speech_mask=speech_mask,
text_mask=text_mask,
speech_seg_pos=speech_seg_pos,
text_seg_pos=text_seg_pos)
if self.decoder is not None:
zs, _ = self.decoder(encoder_out, h_masks)
else:
zs = encoder_out
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
        # default so the return below is well-defined when text_sfc is None
        text_outs = None
        if self.text_sfc:
            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
            text_outs = paddle.reshape(
                self.text_sfc(text_hidden_states),
                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
if self.sfc is not None:
before_outs = paddle.reshape(
self.sfc(speech_hidden_states),
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
else:
before_outs = speech_hidden_states
if self.postnet is not None:
after_outs = before_outs + paddle.transpose(
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
[0, 2, 1])
else:
after_outs = None
return before_outs, after_outs, text_outs
def build_model_from_file(config_file, model_file):
state_dict = paddle.load(model_file)
model_class = MLMDualMaksing if 'conformer_combine_vctk_aishell3_dual_masking' in config_file \
else MLMEncAsDecoder
    # build the model from the yaml config
with open(config_file) as f:
conf = CfgNode(yaml.safe_load(f))
model = build_model(conf, model_class)
model.set_state_dict(state_dict)
return model, conf
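
# --- Hedged usage sketch (added; not part of the original file) -------------
# `build_model_from_file` pairs an experiment yaml with a saved .pdparams
# checkpoint and switches to the dual-masking variant when the config file
# name contains 'conformer_combine_vctk_aishell3_dual_masking'. The paths
# below are placeholders, not files shipped with this change.
#
#     model, conf = build_model_from_file(
#         config_file="exp/ernie_sat/default.yaml",
#         model_file="exp/ernie_sat/checkpoint.pdparams")
#     model.eval()
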
# select encoder and decoder here
def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
        # Overwrite token_list so that the loaded config stays portable.
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
odim = 80
# Encoder
encoder_class = MLMEncoder
if 'text_masking' in args.model_conf.keys() and args.model_conf[
'text_masking']:
args.encoder_conf['text_masking'] = True
else:
args.encoder_conf['text_masking'] = False
encoder = encoder_class(
args.input_size, vocab_size=vocab_size, **args.encoder_conf)
# Decoder
if args.decoder != 'no_decoder':
decoder_class = MLMDecoder
decoder = decoder_class(
idim=0,
input_layer=None,
**args.decoder_conf, )
else:
decoder = None
# Build model
model = model_class(
odim=odim,
encoder=encoder,
decoder=decoder,
**args.model_conf, )
# Initialize
if args.init is not None:
initialize(model, args.init)
return model
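
# --- Hedged usage sketch (added; not part of the original file) -------------
# The fields that build_model() reads from `args` are listed below; every
# concrete value is an illustrative placeholder, not a default of this repo.
#
#     args = argparse.Namespace(
#         token_list="dump/phone_id_map.txt",  # path, or an explicit list
#         input_size=80,                       # input feature (mel) dimension
#         encoder_conf={},                     # kwargs forwarded to MLMEncoder
#         decoder="mlm",                       # anything except 'no_decoder'
#         decoder_conf={},                     # kwargs forwarded to MLMDecoder
#         model_conf={"text_masking": False},  # kwargs forwarded to model_class
#         init=None)                           # or a scheme for initialize()
#     model = build_model(args, model_class=MLMEncAsDecoder)
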

@ -522,82 +522,6 @@ class VITSGenerator(nn.Layer):
        return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def voice_conversion(
self,
feats: paddle.Tensor=None,
feats_lengths: paddle.Tensor=None,
sids_src: Optional[paddle.Tensor]=None,
sids_tgt: Optional[paddle.Tensor]=None,
spembs_src: Optional[paddle.Tensor]=None,
spembs_tgt: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Run voice conversion.
Args:
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids_src (Optional[Tensor]): Speaker index tensor of source feature (B,) or (B, 1).
sids_tgt (Optional[Tensor]): Speaker index tensor of target feature (B,) or (B, 1).
spembs_src (Optional[Tensor]): Speaker embedding tensor of source feature (B, spk_embed_dim).
spembs_tgt (Optional[Tensor]): Speaker embedding tensor of target feature (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Generated waveform tensor (B, T_wav).
"""
# encoder
g_src = None
g_tgt = None
if self.spks is not None:
# (B, global_channels, 1)
g_src = self.global_emb(
paddle.reshape(sids_src, [-1])).unsqueeze(-1)
g_tgt = self.global_emb(
paddle.reshape(sids_tgt, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_src_ = self.spemb_proj(
F.normalize(spembs_src.unsqueeze(0))).unsqueeze(-1)
if g_src is None:
g_src = g_src_
else:
g_src = g_src + g_src_
# (B, global_channels, 1)
g_tgt_ = self.spemb_proj(
F.normalize(spembs_tgt.unsqueeze(0))).unsqueeze(-1)
if g_tgt is None:
g_tgt = g_tgt_
else:
g_tgt = g_tgt + g_tgt_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g_src is None:
g_src = g_
else:
g_src = g_src + g_
if g_tgt is None:
g_tgt = g_
else:
g_tgt = g_tgt + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g_src)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g_src)
# decoder
z_hat = self.flow(z_p, y_mask, g=g_tgt, inverse=True)
wav = self.decoder(z_hat * y_mask, g=g_tgt)
return wav.squeeze(1)
    def _generate_path(self, dur: paddle.Tensor,
                       mask: paddle.Tensor) -> paddle.Tensor:
        """Generate path a.k.a. monotonic attention.

@ -381,7 +381,7 @@ class VITS(nn.Layer):
        if use_teacher_forcing:
            assert feats is not None
            feats = feats[None].transpose([0, 2, 1])
            feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
            feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
@ -406,43 +406,3 @@ class VITS(nn.Layer):
                max_len=max_len, )
        return dict(
            wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
def voice_conversion(
self,
feats: paddle.Tensor,
sids_src: Optional[paddle.Tensor]=None,
sids_tgt: Optional[paddle.Tensor]=None,
spembs_src: Optional[paddle.Tensor]=None,
spembs_tgt: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Run voice conversion.
Args:
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids_src (Optional[Tensor]): Speaker index tensor of source feature (1,).
sids_tgt (Optional[Tensor]): Speaker index tensor of target feature (1,).
spembs_src (Optional[Tensor]): Speaker embedding tensor of source feature (spk_embed_dim,).
spembs_tgt (Optional[Tensor]): Speaker embedding tensor of target feature (spk_embed_dim,).
lids (Optional[Tensor]): Language index tensor (1,).
Returns:
Dict[str, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
"""
assert feats is not None
feats = feats[None].transpose([0, 2, 1])
feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
sids_none = sids_src is None and sids_tgt is None
spembs_none = spembs_src is None and spembs_tgt is None
assert not sids_none or not spembs_none
wav = self.generator.voice_conversion(
feats,
feats_lengths,
sids_src,
sids_tgt,
spembs_src,
spembs_tgt,
lids, )
return dict(wav=paddle.reshape(wav, [-1]))

@ -111,8 +111,6 @@ class VITSUpdater(StandardUpdater):
                text_lengths=batch["text_lengths"],
                feats=batch["feats"],
                feats_lengths=batch["feats_lengths"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
                forward_generator=turn == "generator")
            # Generator
            if turn == "generator":
@ -270,8 +268,6 @@ class VITSEvaluator(StandardEvaluator):
                text_lengths=batch["text_lengths"],
                feats=batch["feats"],
                feats_lengths=batch["feats_lengths"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
                forward_generator=turn == "generator")
            # Generator
            if turn == "generator":

@ -24,11 +24,10 @@ from paddle.nn import Layer
from paddle.optimizer import Optimizer
from timer import timer
from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updater import UpdaterBase
from paddlespeech.t2s.training.updater import UpdaterState
from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
