From c7a7b113c856a455f92b70b256f782d25133bc8d Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Fri, 24 Jun 2022 05:01:44 +0000 Subject: [PATCH] support multi-gpu training with webdataset --- examples/wenetspeech/asr1/conf/conformer.yaml | 35 +- paddlespeech/audio/stream_data/__init__.py | 3 +- paddlespeech/audio/stream_data/filters.py | 39 +- paddlespeech/audio/stream_data/pipeline.py | 6 + paddlespeech/audio/stream_data/shardlists.py | 2 + paddlespeech/audio/utils/log.py | 3 +- paddlespeech/audio/utils/tensor_utils.py | 3 - paddlespeech/s2t/exps/u2/model.py | 269 ++++++---- paddlespeech/s2t/io/dataloader.py | 87 ++++ paddlespeech/s2t/io/reader.py | 2 +- paddlespeech/s2t/transform/__init__.py | 13 - paddlespeech/s2t/transform/add_deltas.py | 54 -- .../s2t/transform/channel_selector.py | 57 --- paddlespeech/s2t/transform/cmvn.py | 201 -------- paddlespeech/s2t/transform/functional.py | 86 ---- paddlespeech/s2t/transform/perturb.py | 471 ----------------- paddlespeech/s2t/transform/spec_augment.py | 214 -------- paddlespeech/s2t/transform/spectrogram.py | 475 ------------------ .../s2t/transform/transform_interface.py | 35 -- paddlespeech/s2t/transform/transformation.py | 158 ------ paddlespeech/s2t/transform/wpe.py | 58 --- 21 files changed, 341 insertions(+), 1930 deletions(-) delete mode 100644 paddlespeech/s2t/transform/__init__.py delete mode 100644 paddlespeech/s2t/transform/add_deltas.py delete mode 100644 paddlespeech/s2t/transform/channel_selector.py delete mode 100644 paddlespeech/s2t/transform/cmvn.py delete mode 100644 paddlespeech/s2t/transform/functional.py delete mode 100644 paddlespeech/s2t/transform/perturb.py delete mode 100644 paddlespeech/s2t/transform/spec_augment.py delete mode 100644 paddlespeech/s2t/transform/spectrogram.py delete mode 100644 paddlespeech/s2t/transform/transform_interface.py delete mode 100644 paddlespeech/s2t/transform/transformation.py delete mode 100644 paddlespeech/s2t/transform/wpe.py diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml index 6c2bbca4..dd4ff0e2 100644 --- a/examples/wenetspeech/asr1/conf/conformer.yaml +++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -50,26 +50,41 @@ test_manifest: data/manifest.test ########################################### # Dataloader # ########################################### -vocab_filepath: data/lang_char/vocab.txt +use_stream_data: True unit_type: 'char' +vocab_filepath: data/lang_char/vocab.txt +cmvn_file: data/mean_std.json preprocess_config: conf/preprocess.yaml spm_model_prefix: '' feat_dim: 80 stride_ms: 10.0 window_ms: 25.0 +dither: 0.1 sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs batch_size: 64 +minlen_in: 10 maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced +minlen_out: 0 maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced -minibatches: 0 # for debug -batch_count: auto -batch_bins: 0 -batch_frames_in: 0 -batch_frames_out: 0 -batch_frames_inout: 0 -num_workers: 0 -subsampling_factor: 1 +resample_rate: 16000 +shuffle_size: 10000 +sort_size: 500 +num_workers: 4 +prefetch_factor: 100 +dist_sampler: True num_encs: 1 +augment_conf: + max_w: 80 + w_inplace: True + w_mode: "PIL" + max_f: 30 + num_f_mask: 2 + f_inplace: True + f_replace_with_zero: False + max_t: 40 + num_t_mask: 2 + t_inplace: True + t_replace_with_zero: False ########################################### @@ -78,7 +93,7 @@ num_encs: 1 
n_epoch: 240 accum_grad: 16 global_grad_clip: 5.0 -log_interval: 100 +log_interval: 1 checkpoint: kbest_n: 50 latest_n: 5 diff --git a/paddlespeech/audio/stream_data/__init__.py b/paddlespeech/audio/stream_data/__init__.py index fdb3458c..e9706d4e 100644 --- a/paddlespeech/audio/stream_data/__init__.py +++ b/paddlespeech/audio/stream_data/__init__.py @@ -41,7 +41,8 @@ from .filters import ( spec_aug, sort, padding, - cmvn + cmvn, + placeholder, ) from webdataset.handlers import ( ignore_and_continue, diff --git a/paddlespeech/audio/stream_data/filters.py b/paddlespeech/audio/stream_data/filters.py index 3112c954..db3e037a 100644 --- a/paddlespeech/audio/stream_data/filters.py +++ b/paddlespeech/audio/stream_data/filters.py @@ -758,27 +758,44 @@ def _compute_fbank(source, compute_fbank = pipelinefilter(_compute_fbank) -def _spec_aug(source, num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80): +def _spec_aug(source, + max_w=5, + w_inplace=True, + w_mode="PIL", + max_f=30, + num_f_mask=2, + f_inplace=True, + f_replace_with_zero=False, + max_t=40, + num_t_mask=2, + t_inplace=True, + t_replace_with_zero=False,): """ Do spec augmentation Inplace operation Args: source: Iterable[{fname, feat, label}] - num_t_mask: number of time mask to apply + max_w: max width of time warp + w_inplace: whether to inplace the original data while time warping + w_mode: time warp mode + max_f: max width of freq mask num_f_mask: number of freq mask to apply + f_inplace: whether to inplace the original data while frequency masking + f_replace_with_zero: use zero to mask max_t: max width of time mask - max_f: max width of freq mask - max_w: max width of time warp - + num_t_mask: number of time mask to apply + t_inplace: whether to inplace the original data while time masking + t_replace_with_zero: use zero to mask + Returns Iterable[{fname, feat, label}] """ for sample in source: x = sample['feat'] x = x.numpy() - x = time_warp(x, max_time_warp=max_w, inplace = True, mode= "PIL") - x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = True, replace_with_zero = False) - x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = True, replace_with_zero = False) + x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode) + x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero) + x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero) sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) yield sample @@ -910,3 +927,9 @@ def _cmvn(source, cmvn_file): label_lengths) cmvn = pipelinefilter(_cmvn) + +def _placeholder(source): + for data in source: + yield data + +placeholder = pipelinefilter(_placeholder) \ No newline at end of file diff --git a/paddlespeech/audio/stream_data/pipeline.py b/paddlespeech/audio/stream_data/pipeline.py index b672773b..e738083f 100644 --- a/paddlespeech/audio/stream_data/pipeline.py +++ b/paddlespeech/audio/stream_data/pipeline.py @@ -89,6 +89,12 @@ class DataPipeline(IterableDataset, PipelineStage): def append(self, f): """Append a pipeline stage (modifies the object).""" self.pipeline.append(f) + return self + + def append_list(self, *args): + for arg in args: + self.pipeline.append(arg) + return self def compose(self, *args): """Append a pipeline stage to a copy of the pipeline and returns the copy.""" diff --git a/paddlespeech/audio/stream_data/shardlists.py b/paddlespeech/audio/stream_data/shardlists.py index 503bfe57..3d1801cc 100644 --- 
a/paddlespeech/audio/stream_data/shardlists.py +++ b/paddlespeech/audio/stream_data/shardlists.py @@ -24,6 +24,8 @@ from .filters import pipelinefilter from .paddle_utils import IterableDataset +from ..utils.log import Logger +logger = Logger(__name__) def expand_urls(urls): if isinstance(urls, str): urllist = urls.split("::") diff --git a/paddlespeech/audio/utils/log.py b/paddlespeech/audio/utils/log.py index 5656b286..0a25bbd5 100644 --- a/paddlespeech/audio/utils/log.py +++ b/paddlespeech/audio/utils/log.py @@ -65,6 +65,7 @@ class Logger(object): def __init__(self, name: str=None): name = 'PaddleAudio' if not name else name + self.name = name self.logger = logging.getLogger(name) for key, conf in log_config.items(): @@ -101,7 +102,7 @@ class Logger(object): if not self.is_enable: return - self.logger.log(log_level, msg) + self.logger.log(log_level, self.name + " | " + msg) @contextlib.contextmanager def use_terminator(self, terminator: str): diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index bae473ec..16f60810 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -93,9 +93,6 @@ def pad_sequence(sequences: List[paddle.Tensor], for i, tensor in enumerate(sequences): length = tensor.shape[0] # use index notation to prevent duplicate references to the tensor - logger.info( - f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}" - ) if batch_first: # TODO (Hui Zhang): set_value op not supprot `end==start` # TODO (Hui Zhang): set_value op not support int16 diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index efcc9629..d6c68f96 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -26,6 +26,7 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader +from paddlespeech.s2t.io.dataloader import StreamDataLoader from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -106,7 +107,8 @@ class U2Trainer(Trainer): @paddle.no_grad() def valid(self): self.model.eval() - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -132,7 +134,7 @@ class U2Trainer(Trainer): msg = f"Valid: Rank: {dist.get_rank()}, " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + #msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -152,7 +154,8 @@ class U2Trainer(Trainer): self.before_train() - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -170,7 +173,8 @@ class U2Trainer(Trainer): self.train_batch(batch_index, batch, msg) self.after_train_batch() report('iter', batch_index + 1) - report('total', len(self.train_loader)) + if not self.use_streamdata: + report('total', len(self.train_loader)) 
report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost'] @@ -218,92 +222,188 @@ class U2Trainer(Trainer): def setup_dataloader(self): config = self.config.clone() - + self.use_streamdata = config.get("use_stream_data", False) if self.train: # train/valid dataset, return token ids - self.train_loader = BatchDataLoader( - json_file=config.train_manifest, - train_mode=True, - sortagrad=config.sortagrad, - batch_size=config.batch_size, - maxlen_in=config.maxlen_in, - maxlen_out=config.maxlen_out, - minibatches=config.minibatches, - mini_batch_size=self.args.ngpu, - batch_count=config.batch_count, - batch_bins=config.batch_bins, - batch_frames_in=config.batch_frames_in, - batch_frames_out=config.batch_frames_out, - batch_frames_inout=config.batch_frames_inout, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) - - self.valid_loader = BatchDataLoader( - json_file=config.dev_manifest, - train_mode=False, - sortagrad=False, - batch_size=config.batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=self.args.ngpu, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=config.num_workers, - subsampling_factor=1, - num_encs=1, - dist_sampler=config.get('dist_sampler', False), - shortest_first=False) + if self.use_streamdata: + self.train_loader = StreamDataLoader( + manifest_file=config.train_manifest, + train_mode=True, + unit_type=config.unit_type, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + augment_conf=config.augment_conf, # dict + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.get('dist_sampler', False), + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + self.valid_loader = StreamDataLoader( + manifest_file=config.dev_manifest, + train_mode=False, + unit_type=config.unit_type, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + augment_conf=config.augment_conf, # dict + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.get('dist_sampler', False), + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + else: + self.train_loader = BatchDataLoader( + json_file=config.train_manifest, + train_mode=True, + sortagrad=config.sortagrad, + batch_size=config.batch_size, + maxlen_in=config.maxlen_in, + maxlen_out=config.maxlen_out, + minibatches=config.minibatches, + mini_batch_size=self.args.ngpu, + batch_count=config.batch_count, + batch_bins=config.batch_bins, + batch_frames_in=config.batch_frames_in, + 
batch_frames_out=config.batch_frames_out, + batch_frames_inout=config.batch_frames_inout, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + shortest_first=False) + + self.valid_loader = BatchDataLoader( + json_file=config.dev_manifest, + train_mode=False, + sortagrad=False, + batch_size=config.batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=self.args.ngpu, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=config.num_workers, + subsampling_factor=1, + num_encs=1, + dist_sampler=config.get('dist_sampler', False), + shortest_first=False) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) # test dataset, return raw text - self.test_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) - - self.align_loader = BatchDataLoader( - json_file=config.test_manifest, - train_mode=False, - sortagrad=False, - batch_size=decode_batch_size, - maxlen_in=float('inf'), - maxlen_out=float('inf'), - minibatches=0, - mini_batch_size=1, - batch_count='auto', - batch_bins=0, - batch_frames_in=0, - batch_frames_out=0, - batch_frames_inout=0, - preprocess_conf=config.preprocess_config, - n_iter_processes=1, - subsampling_factor=1, - num_encs=1) + if self.use_streamdata: + self.test_loader = StreamDataLoader( + manifest_file=config.test_manifest, + train_mode=False, + unit_type=config.unit_type, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=0.0, + minlen_in=0.0, + maxlen_in=float('inf'), + minlen_out=0, + maxlen_out=float('inf'), + resample_rate=config.resample_rate, + augment_conf=config.augment_conf, # dict + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.get('dist_sampler', False), + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + self.align_loader = StreamDataLoader( + manifest_file=config.test_manifest, + train_mode=False, + unit_type=config.unit_type, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=0.0, + minlen_in=0.0, + maxlen_in=float('inf'), + minlen_out=0, + maxlen_out=float('inf'), + resample_rate=config.resample_rate, + augment_conf=config.augment_conf, # dict + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.get('dist_sampler', False), + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, + ) + else: + self.test_loader = BatchDataLoader( + json_file=config.test_manifest, + train_mode=False, + sortagrad=False, + batch_size=decode_batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + 
mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) + + self.align_loader = BatchDataLoader( + json_file=config.test_manifest, + train_mode=False, + sortagrad=False, + batch_size=decode_batch_size, + maxlen_in=float('inf'), + maxlen_out=float('inf'), + minibatches=0, + mini_batch_size=1, + batch_count='auto', + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + preprocess_conf=config.preprocess_config, + n_iter_processes=1, + subsampling_factor=1, + num_encs=1) logger.info("Setup test/align Dataloader!") def setup_model(self): @@ -452,7 +552,8 @@ class U2Tester(U2Trainer): def test(self): assert self.args.result_file self.model.eval() - logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") + if not self.use_streamdata: + logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") stride_ms = self.config.stride_ms error_rate_type = None diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 55aa13ff..c27969f0 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -28,6 +28,9 @@ from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log +import paddlespeech.audio.stream_data as stream_data +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer + __all__ = ["BatchDataLoader"] logger = Log(__name__).getlog() @@ -56,6 +59,90 @@ def batch_collate(x): """ return x[0] +class StreamDataLoader(): + def __init__(self, + manifest_file: str, + train_mode: bool, + unit_type: str='char', + batch_size: int=0, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + minlen_in: float=0.0, + maxlen_in: float=float('inf'), + minlen_out: float=0.0, + maxlen_out: float=float('inf'), + resample_rate: int=16000, + augment_conf: dict=None, + shuffle_size: int=10000, + sort_size: int=1000, + n_iter_processes: int=1, + prefetch_factor: int=2, + dist_sampler: bool=False, + cmvn_file="data/mean_std.json", + vocab_filepath='data/lang_char/vocab.txt'): + self.manifest_file = manifest_file + self.train_mode = train_mode + self.batch_size = batch_size + self.prefetch_factor = prefetch_factor + self.dist_sampler = dist_sampler + self.n_iter_processes = n_iter_processes + + text_featurizer = TextFeaturizer(unit_type, vocab_filepath) + symbol_table = text_featurizer.vocab_dict + self.feat_dim = num_mel_bins + self.vocab_size = text_featurizer.vocab_size + + # The list of shards + shardlist = [] + with open(manifest_file, "r") as f: + for line in f.readlines(): + shardlist.append(line.strip()) + + if self.dist_sampler: + base_dataset = stream_data.DataPipeline( + stream_data.SimpleShardList(shardlist), + stream_data.split_by_node, + stream_data.split_by_worker, + stream_data.tarfile_to_samples(stream_data.reraise_exception) + ) + else: + base_dataset = stream_data.DataPipeline( + stream_data.SimpleShardList(shardlist), + stream_data.split_by_worker, + stream_data.tarfile_to_samples(stream_data.reraise_exception) + ) + + self.dataset = base_dataset.append_list( + stream_data.tokenize(symbol_table), + stream_data.data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), +
stream_data.resample(resample_rate=resample_rate), + stream_data.compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), + stream_data.spec_aug(**augment_conf) if train_mode else stream_data.placeholder(), + stream_data.shuffle(shuffle_size), + stream_data.sort(sort_size=sort_size), + stream_data.batched(batch_size), + stream_data.padding(), + stream_data.cmvn(cmvn_file) + ) + self.loader = stream_data.WebLoader( + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor=self.prefetch_factor, + batch_size=None + ) + + def __iter__(self): + return self.loader.__iter__() + + def __call__(self): + return self.__iter__() + + def __len__(self): + logger.info("Stream dataloader does not support calculating the length of the dataset") + return -1 + class BatchDataLoader(): def __init__(self, diff --git a/paddlespeech/s2t/io/reader.py b/paddlespeech/s2t/io/reader.py index 4e136bdc..5e018bef 100644 --- a/paddlespeech/s2t/io/reader.py +++ b/paddlespeech/s2t/io/reader.py @@ -19,7 +19,7 @@ import numpy as np import soundfile from .utility import feat_type -from paddlespeech.s2t.transform.transformation import Transformation +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.log import Log # from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline as Transformation diff --git a/paddlespeech/s2t/transform/__init__.py b/paddlespeech/s2t/transform/__init__.py deleted file mode 100644 index 185a92b8..00000000 --- a/paddlespeech/s2t/transform/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/paddlespeech/s2t/transform/add_deltas.py b/paddlespeech/s2t/transform/add_deltas.py deleted file mode 100644 index 1387fe9d..00000000 --- a/paddlespeech/s2t/transform/add_deltas.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-# Modified from espnet(https://github.com/espnet/espnet) -import numpy as np - - -def delta(feat, window): - assert window > 0 - delta_feat = np.zeros_like(feat) - for i in range(1, window + 1): - delta_feat[:-i] += i * feat[i:] - delta_feat[i:] += -i * feat[:-i] - delta_feat[-i:] += i * feat[-1] - delta_feat[:i] += -i * feat[0] - delta_feat /= 2 * sum(i**2 for i in range(1, window + 1)) - return delta_feat - - -def add_deltas(x, window=2, order=2): - """ - Args: - x (np.ndarray): speech feat, (T, D). - - Return: - np.ndarray: (T, (1+order)*D) - """ - feats = [x] - for _ in range(order): - feats.append(delta(feats[-1], window)) - return np.concatenate(feats, axis=1) - - -class AddDeltas(): - def __init__(self, window=2, order=2): - self.window = window - self.order = order - - def __repr__(self): - return "{name}(window={window}, order={order}".format( - name=self.__class__.__name__, window=self.window, order=self.order) - - def __call__(self, x): - return add_deltas(x, window=self.window, order=self.order) diff --git a/paddlespeech/s2t/transform/channel_selector.py b/paddlespeech/s2t/transform/channel_selector.py deleted file mode 100644 index b078dcf8..00000000 --- a/paddlespeech/s2t/transform/channel_selector.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import numpy - - -class ChannelSelector(): - """Select 1ch from multi-channel signal""" - - def __init__(self, train_channel="random", eval_channel=0, axis=1): - self.train_channel = train_channel - self.eval_channel = eval_channel - self.axis = axis - - def __repr__(self): - return ("{name}(train_channel={train_channel}, " - "eval_channel={eval_channel}, axis={axis})".format( - name=self.__class__.__name__, - train_channel=self.train_channel, - eval_channel=self.eval_channel, - axis=self.axis, )) - - def __call__(self, x, train=True): - # Assuming x: [Time, Channel] by default - - if x.ndim <= self.axis: - # If the dimension is insufficient, then unsqueeze - # (e.g [Time] -> [Time, 1]) - ind = tuple( - slice(None) if i < x.ndim else None - for i in range(self.axis + 1)) - x = x[ind] - - if train: - channel = self.train_channel - else: - channel = self.eval_channel - - if channel == "random": - ch = numpy.random.randint(0, x.shape[self.axis]) - else: - ch = channel - - ind = tuple( - slice(None) if i != self.axis else ch for i in range(x.ndim)) - return x[ind] diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/s2t/transform/cmvn.py deleted file mode 100644 index 2db0070b..00000000 --- a/paddlespeech/s2t/transform/cmvn.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import io -import json - -import h5py -import kaldiio -import numpy as np - - -class CMVN(): - "Apply Global/Spk CMVN/iverserCMVN." - - def __init__( - self, - stats, - norm_means=True, - norm_vars=False, - filetype="mat", - utt2spk=None, - spk2utt=None, - reverse=False, - std_floor=1.0e-20, ): - self.stats_file = stats - self.norm_means = norm_means - self.norm_vars = norm_vars - self.reverse = reverse - - if isinstance(stats, dict): - stats_dict = dict(stats) - else: - # Use for global CMVN - if filetype == "mat": - stats_dict = {None: kaldiio.load_mat(stats)} - # Use for global CMVN - elif filetype == "npy": - stats_dict = {None: np.load(stats)} - # Use for speaker CMVN - elif filetype == "ark": - self.accept_uttid = True - stats_dict = dict(kaldiio.load_ark(stats)) - # Use for speaker CMVN - elif filetype == "hdf5": - self.accept_uttid = True - stats_dict = h5py.File(stats) - else: - raise ValueError("Not supporting filetype={}".format(filetype)) - - if utt2spk is not None: - self.utt2spk = {} - with io.open(utt2spk, "r", encoding="utf-8") as f: - for line in f: - utt, spk = line.rstrip().split(None, 1) - self.utt2spk[utt] = spk - elif spk2utt is not None: - self.utt2spk = {} - with io.open(spk2utt, "r", encoding="utf-8") as f: - for line in f: - spk, utts = line.rstrip().split(None, 1) - for utt in utts.split(): - self.utt2spk[utt] = spk - else: - self.utt2spk = None - - # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1), - # and the first vector contains the sum of feats and the second is - # the sum of squares. The last value of the first, i.e. stats[0,-1], - # is the number of samples for this statistics. 
- self.bias = {} - self.scale = {} - for spk, stats in stats_dict.items(): - assert len(stats) == 2, stats.shape - - count = stats[0, -1] - - # If the feature has two or more dimensions - if not (np.isscalar(count) or isinstance(count, (int, float))): - # The first is only used - count = count.flatten()[0] - - mean = stats[0, :-1] / count - # V(x) = E(x^2) - (E(x))^2 - var = stats[1, :-1] / count - mean * mean - std = np.maximum(np.sqrt(var), std_floor) - self.bias[spk] = -mean - self.scale[spk] = 1 / std - - def __repr__(self): - return ("{name}(stats_file={stats_file}, " - "norm_means={norm_means}, norm_vars={norm_vars}, " - "reverse={reverse})".format( - name=self.__class__.__name__, - stats_file=self.stats_file, - norm_means=self.norm_means, - norm_vars=self.norm_vars, - reverse=self.reverse, )) - - def __call__(self, x, uttid=None): - if self.utt2spk is not None: - spk = self.utt2spk[uttid] - else: - spk = uttid - - if not self.reverse: - # apply cmvn - if self.norm_means: - x = np.add(x, self.bias[spk]) - if self.norm_vars: - x = np.multiply(x, self.scale[spk]) - - else: - # apply reverse cmvn - if self.norm_vars: - x = np.divide(x, self.scale[spk]) - if self.norm_means: - x = np.subtract(x, self.bias[spk]) - - return x - - -class UtteranceCMVN(): - "Apply Utterance CMVN" - - def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20): - self.norm_means = norm_means - self.norm_vars = norm_vars - self.std_floor = std_floor - - def __repr__(self): - return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format( - name=self.__class__.__name__, - norm_means=self.norm_means, - norm_vars=self.norm_vars, ) - - def __call__(self, x, uttid=None): - # x: [Time, Dim] - square_sums = (x**2).sum(axis=0) - mean = x.mean(axis=0) - - if self.norm_means: - x = np.subtract(x, mean) - - if self.norm_vars: - var = square_sums / x.shape[0] - mean**2 - std = np.maximum(np.sqrt(var), self.std_floor) - x = np.divide(x, std) - - return x - - -class GlobalCMVN(): - "Apply Global CMVN" - - def __init__(self, - cmvn_path, - norm_means=True, - norm_vars=True, - std_floor=1.0e-20): - # cmvn_path: Option[str, dict] - cmvn = cmvn_path - self.cmvn = cmvn - self.norm_means = norm_means - self.norm_vars = norm_vars - self.std_floor = std_floor - if isinstance(cmvn, dict): - cmvn_stats = cmvn - else: - with open(cmvn) as f: - cmvn_stats = json.load(f) - self.count = cmvn_stats['frame_num'] - self.mean = np.array(cmvn_stats['mean_stat']) / self.count - self.square_sums = np.array(cmvn_stats['var_stat']) - self.var = self.square_sums / self.count - self.mean**2 - self.std = np.maximum(np.sqrt(self.var), self.std_floor) - - def __repr__(self): - return f"""{self.__class__.__name__}( - cmvn_path={self.cmvn}, - norm_means={self.norm_means}, - norm_vars={self.norm_vars},)""" - - def __call__(self, x, uttid=None): - # x: [Time, Dim] - if self.norm_means: - x = np.subtract(x, self.mean) - - if self.norm_vars: - x = np.divide(x, self.std) - return x diff --git a/paddlespeech/s2t/transform/functional.py b/paddlespeech/s2t/transform/functional.py deleted file mode 100644 index ccb50081..00000000 --- a/paddlespeech/s2t/transform/functional.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import inspect - -from paddlespeech.s2t.transform.transform_interface import TransformInterface -from paddlespeech.s2t.utils.check_kwargs import check_kwargs - - -class FuncTrans(TransformInterface): - """Functional Transformation - - WARNING: - Builtin or C/C++ functions may not work properly - because this class heavily depends on the `inspect` module. - - Usage: - - >>> def foo_bar(x, a=1, b=2): - ... '''Foo bar - ... :param x: input - ... :param int a: default 1 - ... :param int b: default 2 - ... ''' - ... return x + a - b - - - >>> class FooBar(FuncTrans): - ... _func = foo_bar - ... __doc__ = foo_bar.__doc__ - """ - - _func = None - - def __init__(self, **kwargs): - self.kwargs = kwargs - check_kwargs(self.func, kwargs) - - def __call__(self, x): - return self.func(x, **self.kwargs) - - @classmethod - def add_arguments(cls, parser): - fname = cls._func.__name__.replace("_", "-") - group = parser.add_argument_group(fname + " transformation setting") - for k, v in cls.default_params().items(): - # TODO(karita): get help and choices from docstring? - attr = k.replace("_", "-") - group.add_argument(f"--{fname}-{attr}", default=v, type=type(v)) - return parser - - @property - def func(self): - return type(self)._func - - @classmethod - def default_params(cls): - try: - d = dict(inspect.signature(cls._func).parameters) - except ValueError: - d = dict() - return { - k: v.default - for k, v in d.items() if v.default != inspect.Parameter.empty - } - - def __repr__(self): - params = self.default_params() - params.update(**self.kwargs) - ret = self.__class__.__name__ + "(" - if len(params) == 0: - return ret + ")" - for k, v in params.items(): - ret += "{}={}, ".format(k, v) - return ret[:-2] + ")" diff --git a/paddlespeech/s2t/transform/perturb.py b/paddlespeech/s2t/transform/perturb.py deleted file mode 100644 index b18caefb..00000000 --- a/paddlespeech/s2t/transform/perturb.py +++ /dev/null @@ -1,471 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import librosa -import numpy -import scipy -import soundfile - -from paddlespeech.s2t.io.reader import SoundHDF5File - - -class SpeedPerturbation(): - """SpeedPerturbation - - The speed perturbation in kaldi uses sox-speed instead of sox-tempo, - and sox-speed just to resample the input, - i.e pitch and tempo are changed both. 
- - "Why use speed option instead of tempo -s in SoX for speed perturbation" - https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 - - Warning: - This function is very slow because of resampling. - I recommmend to apply speed-perturb outside the training using sox. - - """ - - def __init__( - self, - lower=0.9, - upper=1.1, - utt2ratio=None, - keep_length=True, - res_type="kaiser_best", - seed=None, ): - self.res_type = res_type - self.keep_length = keep_length - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - self.utt2ratio = {} - # Use the scheduled ratio for each utterances - self.utt2ratio_file = utt2ratio - self.lower = None - self.upper = None - self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - self.utt2ratio = None - # The ratio is given on runtime randomly - self.lower = lower - self.upper = upper - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format( - self.__class__.__name__, - self.lower, - self.upper, - self.keep_length, - self.res_type, ) - else: - return "{}({}, res_type={})".format( - self.__class__.__name__, self.utt2ratio_file, self.res_type) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - x = x.astype(numpy.float32) - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - # Note1: resample requires the sampling-rate of input and output, - # but actually only the ratio is used. - y = librosa.resample( - x, orig_sr=ratio, target_sr=1, res_type=self.res_type) - - if self.keep_length: - diff = abs(len(x) - len(y)) - if len(y) > len(x): - # Truncate noise - y = y[diff // 2:-((diff + 1) // 2)] - elif len(y) < len(x): - # Assume the time-axis is the first: (Time, Channel) - pad_width = [(diff // 2, (diff + 1) // 2)] + [ - (0, 0) for _ in range(y.ndim - 1) - ] - y = numpy.pad( - y, pad_width=pad_width, constant_values=0, mode="constant") - return y - - -class SpeedPerturbationSox(): - """SpeedPerturbationSox - - The speed perturbation in kaldi uses sox-speed instead of sox-tempo, - and sox-speed just to resample the input, - i.e pitch and tempo are changed both. - - To speed up or slow down the sound of a file, - use speed to modify the pitch and the duration of the file. - This raises the speed and reduces the time. - The default factor is 1.0 which makes no change to the audio. - 2.0 doubles speed, thus time length is cut by a half and pitch is one interval higher. - - "Why use speed option instead of tempo -s in SoX for speed perturbation" - https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 - - tempo option: - sox -t wav input.wav -t wav output.tempo0.9.wav tempo -s 0.9 - - speed option: - sox -t wav input.wav -t wav output.speed0.9.wav speed 0.9 - - If we use speed option like above, the pitch of audio also will be changed, - but the tempo option does not change the pitch. 
- """ - - def __init__( - self, - lower=0.9, - upper=1.1, - utt2ratio=None, - keep_length=True, - sr=16000, - seed=None, ): - self.sr = sr - self.keep_length = keep_length - self.state = numpy.random.RandomState(seed) - - try: - import soxbindings as sox - except ImportError: - try: - from paddlespeech.s2t.utils import dynamic_pip_install - package = "sox" - dynamic_pip_install.install(package) - package = "soxbindings" - if sys.platform != "win32": - dynamic_pip_install.install(package) - import soxbindings as sox - except Exception: - raise RuntimeError( - "Can not install soxbindings on your system.") - self.sox = sox - - if utt2ratio is not None: - self.utt2ratio = {} - # Use the scheduled ratio for each utterances - self.utt2ratio_file = utt2ratio - self.lower = None - self.upper = None - self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - self.utt2ratio = None - # The ratio is given on runtime randomly - self.lower = lower - self.upper = upper - - def __repr__(self): - if self.utt2ratio is None: - return f"""{self.__class__.__name__}( - lower={self.lower}, - upper={self.upper}, - keep_length={self.keep_length}, - sample_rate={self.sr})""" - - else: - return f"""{self.__class__.__name__}( - utt2ratio={self.utt2ratio_file}, - sample_rate={self.sr})""" - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - tfm = self.sox.Transformer() - tfm.set_globals(multithread=False) - tfm.speed(ratio) - y = tfm.build_array(input_array=x, sample_rate_in=self.sr) - - if self.keep_length: - diff = abs(len(x) - len(y)) - if len(y) > len(x): - # Truncate noise - y = y[diff // 2:-((diff + 1) // 2)] - elif len(y) < len(x): - # Assume the time-axis is the first: (Time, Channel) - pad_width = [(diff // 2, (diff + 1) // 2)] + [ - (0, 0) for _ in range(y.ndim - 1) - ] - y = numpy.pad( - y, pad_width=pad_width, constant_values=0, mode="constant") - - if y.ndim == 2 and x.ndim == 1: - # (T, C) -> (T) - y = y.sequence(1) - return y - - -class BandpassPerturbation(): - """BandpassPerturbation - - Randomly dropout along the frequency axis. - - The original idea comes from the following: - "randomly-selected frequency band was cut off under the constraint of - leaving at least 1,000 Hz band within the range of less than 4,000Hz." 
- (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for - everyday home environments using multiple microphone arrays; - http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf) - - """ - - def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1, )): - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - # x_stft: (Time, Channel, Freq) - self.axes = axes - - def __repr__(self): - return "{}(lower={}, upper={})".format(self.__class__.__name__, - self.lower, self.upper) - - def __call__(self, x_stft, uttid=None, train=True): - if not train: - return x_stft - - if x_stft.ndim == 1: - raise RuntimeError("Input in time-freq domain: " - "(Time, Channel, Freq) or (Time, Freq)") - - ratio = self.state.uniform(self.lower, self.upper) - axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] - shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)] - - mask = self.state.randn(*shape) > ratio - x_stft *= mask - return x_stft - - -class VolumePerturbation(): - def __init__(self, - lower=-1.6, - upper=1.6, - utt2ratio=None, - dbunit=True, - seed=None): - self.dbunit = dbunit - self.utt2ratio_file = utt2ratio - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - # Use the scheduled ratio for each utterances - self.utt2ratio = {} - self.lower = None - self.upper = None - self.accept_uttid = True - - with open(utt2ratio, "r") as f: - for line in f: - utt, ratio = line.rstrip().split(None, 1) - ratio = float(ratio) - self.utt2ratio[utt] = ratio - else: - # The ratio is given on runtime randomly - self.utt2ratio = None - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, dbunit={})".format( - self.__class__.__name__, self.lower, self.upper, self.dbunit) - else: - return '{}("{}", dbunit={})'.format( - self.__class__.__name__, self.utt2ratio_file, self.dbunit) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - - if self.accept_uttid: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - if self.dbunit: - ratio = 10**(ratio / 20) - return x * ratio - - -class NoiseInjection(): - """Add isotropic noise""" - - def __init__( - self, - utt2noise=None, - lower=-20, - upper=-5, - utt2ratio=None, - filetype="list", - dbunit=True, - seed=None, ): - self.utt2noise_file = utt2noise - self.utt2ratio_file = utt2ratio - self.filetype = filetype - self.dbunit = dbunit - self.lower = lower - self.upper = upper - self.state = numpy.random.RandomState(seed) - - if utt2ratio is not None: - # Use the scheduled ratio for each utterances - self.utt2ratio = {} - with open(utt2noise, "r") as f: - for line in f: - utt, snr = line.rstrip().split(None, 1) - snr = float(snr) - self.utt2ratio[utt] = snr - else: - # The ratio is given on runtime randomly - self.utt2ratio = None - - if utt2noise is not None: - self.utt2noise = {} - if filetype == "list": - with open(utt2noise, "r") as f: - for line in f: - utt, filename = line.rstrip().split(None, 1) - signal, rate = soundfile.read(filename, dtype="int16") - # Load all files in memory - self.utt2noise[utt] = (signal, rate) - - elif filetype == "sound.hdf5": - self.utt2noise = SoundHDF5File(utt2noise, "r") - else: - raise ValueError(filetype) - else: - self.utt2noise = None - - if utt2noise is not None and utt2ratio is not None: - if set(self.utt2ratio) != set(self.utt2noise): - raise 
RuntimeError("The uttids mismatch between {} and {}". - format(utt2ratio, utt2noise)) - - def __repr__(self): - if self.utt2ratio is None: - return "{}(lower={}, upper={}, dbunit={})".format( - self.__class__.__name__, self.lower, self.upper, self.dbunit) - else: - return '{}("{}", dbunit={})'.format( - self.__class__.__name__, self.utt2ratio_file, self.dbunit) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - x = x.astype(numpy.float32) - - # 1. Get ratio of noise to signal in sound pressure level - if uttid is not None and self.utt2ratio is not None: - ratio = self.utt2ratio[uttid] - else: - ratio = self.state.uniform(self.lower, self.upper) - - if self.dbunit: - ratio = 10**(ratio / 20) - scale = ratio * numpy.sqrt((x**2).mean()) - - # 2. Get noise - if self.utt2noise is not None: - # Get noise from the external source - if uttid is not None: - noise, rate = self.utt2noise[uttid] - else: - # Randomly select the noise source - noise = self.state.choice(list(self.utt2noise.values())) - # Normalize the level - noise /= numpy.sqrt((noise**2).mean()) - - # Adjust the noise length - diff = abs(len(x) - len(noise)) - offset = self.state.randint(0, diff) - if len(noise) > len(x): - # Truncate noise - noise = noise[offset:-(diff - offset)] - else: - noise = numpy.pad( - noise, pad_width=[offset, diff - offset], mode="wrap") - - else: - # Generate white noise - noise = self.state.normal(0, 1, x.shape) - - # 3. Add noise to signal - return x + noise * scale - - -class RIRConvolve(): - def __init__(self, utt2rir, filetype="list"): - self.utt2rir_file = utt2rir - self.filetype = filetype - - self.utt2rir = {} - if filetype == "list": - with open(utt2rir, "r") as f: - for line in f: - utt, filename = line.rstrip().split(None, 1) - signal, rate = soundfile.read(filename, dtype="int16") - self.utt2rir[utt] = (signal, rate) - - elif filetype == "sound.hdf5": - self.utt2rir = SoundHDF5File(utt2rir, "r") - else: - raise NotImplementedError(filetype) - - def __repr__(self): - return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file) - - def __call__(self, x, uttid=None, train=True): - if not train: - return x - - x = x.astype(numpy.float32) - - if x.ndim != 1: - # Must be single channel - raise RuntimeError( - "Input x must be one dimensional array, but got {}".format( - x.shape)) - - rir, rate = self.utt2rir[uttid] - if rir.ndim == 2: - # FIXME(kamo): Use chainer.convolution_1d? - # return [Time, Channel] - return numpy.stack( - [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) - else: - return scipy.convolve(x, rir, mode="same") diff --git a/paddlespeech/s2t/transform/spec_augment.py b/paddlespeech/s2t/transform/spec_augment.py deleted file mode 100644 index 5ce95085..00000000 --- a/paddlespeech/s2t/transform/spec_augment.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Modified from espnet(https://github.com/espnet/espnet) -"""Spec Augment module for preprocessing i.e., data augmentation""" -import random - -import numpy -from PIL import Image -from PIL.Image import BICUBIC - -from paddlespeech.s2t.transform.functional import FuncTrans - - -def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): - """time warp for spec augment - - move random center frame by the random width ~ uniform(-window, window) - :param numpy.ndarray x: spectrogram (time, freq) - :param int max_time_warp: maximum time frames to warp - :param bool inplace: overwrite x with the result - :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" - (slow, differentiable) - :returns numpy.ndarray: time warped spectrogram (time, freq) - """ - window = max_time_warp - if window == 0: - return x - - if mode == "PIL": - t = x.shape[0] - if t - window <= window: - return x - # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 - center = random.randrange(window, t - window) - warped = random.randrange(center - window, center + - window) + 1 # 1 ... t - 1 - - left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) - right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), - BICUBIC) - if inplace: - x[:warped] = left - x[warped:] = right - return x - return numpy.concatenate((left, right), 0) - elif mode == "sparse_image_warp": - import paddle - - from espnet.utils import spec_augment - - # TODO(karita): make this differentiable again - return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() - else: - raise NotImplementedError("unknown resize mode: " + mode + - ", choose one from (PIL, sparse_image_warp).") - - -class TimeWarp(FuncTrans): - _func = time_warp - __doc__ = time_warp.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): - """freq mask for spec agument - - :param numpy.ndarray x: (time, freq) - :param int n_mask: the number of masks - :param bool inplace: overwrite - :param bool replace_with_zero: pad zero on mask if true else use mean - """ - if inplace: - cloned = x - else: - cloned = x.copy() - - num_mel_channels = cloned.shape[1] - fs = numpy.random.randint(0, F, size=(n_mask, 2)) - - for f, mask_end in fs: - f_zero = random.randrange(0, num_mel_channels - f) - mask_end += f_zero - - # avoids randrange error if values are equal and range is empty - if f_zero == f_zero + f: - continue - - if replace_with_zero: - cloned[:, f_zero:mask_end] = 0 - else: - cloned[:, f_zero:mask_end] = cloned.mean() - return cloned - - -class FreqMask(FuncTrans): - _func = freq_mask - __doc__ = freq_mask.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): - """freq mask for spec agument - - :param numpy.ndarray spec: (time, freq) - :param int n_mask: the number of masks - :param bool inplace: overwrite - :param bool replace_with_zero: pad zero on mask if true else use mean - """ - if inplace: - cloned = spec - else: - cloned = spec.copy() - len_spectro = cloned.shape[0] - ts = numpy.random.randint(0, T, size=(n_mask, 2)) - for t, mask_end in ts: - # avoid randint range error - if len_spectro - t <= 0: - continue - t_zero = random.randrange(0, len_spectro - t) - - # avoids randrange error if values are equal and range is empty - if t_zero == t_zero + t: - continue - - mask_end += t_zero 
- if replace_with_zero: - cloned[t_zero:mask_end] = 0 - else: - cloned[t_zero:mask_end] = cloned.mean() - return cloned - - -class TimeMask(FuncTrans): - _func = time_mask - __doc__ = time_mask.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) - - -def spec_augment( - x, - resize_mode="PIL", - max_time_warp=80, - max_freq_width=27, - n_freq_mask=2, - max_time_width=100, - n_time_mask=2, - inplace=True, - replace_with_zero=True, ): - """spec agument - - apply random time warping and time/freq masking - default setting is based on LD (Librispeech double) in Table 2 - https://arxiv.org/pdf/1904.08779.pdf - - :param numpy.ndarray x: (time, freq) - :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" - (slow, differentiable) - :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) - :param int freq_mask_width: maximum width of the random freq mask (F) - :param int n_freq_mask: the number of the random freq mask (m_F) - :param int time_mask_width: maximum width of the random time mask (T) - :param int n_time_mask: the number of the random time mask (m_T) - :param bool inplace: overwrite intermediate array - :param bool replace_with_zero: pad zero on mask if true else use mean - """ - assert isinstance(x, numpy.ndarray) - assert x.ndim == 2 - x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) - x = freq_mask( - x, - max_freq_width, - n_freq_mask, - inplace=inplace, - replace_with_zero=replace_with_zero, ) - x = time_mask( - x, - max_time_width, - n_time_mask, - inplace=inplace, - replace_with_zero=replace_with_zero, ) - return x - - -class SpecAugment(FuncTrans): - _func = spec_augment - __doc__ = spec_augment.__doc__ - - def __call__(self, x, train): - if not train: - return x - return super().__call__(x) diff --git a/paddlespeech/s2t/transform/spectrogram.py b/paddlespeech/s2t/transform/spectrogram.py deleted file mode 100644 index 19f0237b..00000000 --- a/paddlespeech/s2t/transform/spectrogram.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -import librosa -import numpy as np -import paddle -from python_speech_features import logfbank - -import paddlespeech.audio.compliance.kaldi as kaldi - - -def stft(x, - n_fft, - n_shift, - win_length=None, - window="hann", - center=True, - pad_mode="reflect"): - # x: [Time, Channel] - if x.ndim == 1: - single_channel = True - # x: [Time] -> [Time, Channel] - x = x[:, None] - else: - single_channel = False - x = x.astype(np.float32) - - # FIXME(kamo): librosa.stft can't use multi-channel? 
- # x: [Time, Channel, Freq] - x = np.stack( - [ - librosa.stft( - y=x[:, ch], - n_fft=n_fft, - hop_length=n_shift, - win_length=win_length, - window=window, - center=center, - pad_mode=pad_mode, ).T for ch in range(x.shape[1]) - ], - axis=1, ) - - if single_channel: - # x: [Time, Channel, Freq] -> [Time, Freq] - x = x[:, 0] - return x - - -def istft(x, n_shift, win_length=None, window="hann", center=True): - # x: [Time, Channel, Freq] - if x.ndim == 2: - single_channel = True - # x: [Time, Freq] -> [Time, Channel, Freq] - x = x[:, None, :] - else: - single_channel = False - - # x: [Time, Channel] - x = np.stack( - [ - librosa.istft( - stft_matrix=x[:, ch].T, # [Time, Freq] -> [Freq, Time] - hop_length=n_shift, - win_length=win_length, - window=window, - center=center, ) for ch in range(x.shape[1]) - ], - axis=1, ) - - if single_channel: - # x: [Time, Channel] -> [Time] - x = x[:, 0] - return x - - -def stft2logmelspectrogram(x_stft, - fs, - n_mels, - n_fft, - fmin=None, - fmax=None, - eps=1e-10): - # x_stft: (Time, Channel, Freq) or (Time, Freq) - fmin = 0 if fmin is None else fmin - fmax = fs / 2 if fmax is None else fmax - - # spc: (Time, Channel, Freq) or (Time, Freq) - spc = np.abs(x_stft) - # mel_basis: (Mel_freq, Freq) - mel_basis = librosa.filters.mel( - sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) - # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) - lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) - - return lmspc - - -def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): - # x: (Time, Channel) -> spc: (Time, Channel, Freq) - spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) - return spc - - -def logmelspectrogram( - x, - fs, - n_mels, - n_fft, - n_shift, - win_length=None, - window="hann", - fmin=None, - fmax=None, - eps=1e-10, - pad_mode="reflect", ): - # stft: (Time, Channel, Freq) or (Time, Freq) - x_stft = stft( - x, - n_fft=n_fft, - n_shift=n_shift, - win_length=win_length, - window=window, - pad_mode=pad_mode, ) - - return stft2logmelspectrogram( - x_stft, - fs=fs, - n_mels=n_mels, - n_fft=n_fft, - fmin=fmin, - fmax=fmax, - eps=eps) - - -class Spectrogram(): - def __init__(self, n_fft, n_shift, win_length=None, window="hann"): - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - - def __repr__(self): - return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " - "win_length={win_length}, window={window})".format( - name=self.__class__.__name__, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, )) - - def __call__(self, x): - return spectrogram( - x, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, ) - - -class LogMelSpectrogram(): - def __init__( - self, - fs, - n_mels, - n_fft, - n_shift, - win_length=None, - window="hann", - fmin=None, - fmax=None, - eps=1e-10, ): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.fmin = fmin - self.fmax = fmax - self.eps = eps - - def __repr__(self): - return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "n_shift={n_shift}, win_length={win_length}, window={window}, " - "fmin={fmin}, fmax={fmax}, eps={eps}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, )) - - def 
__call__(self, x): - return logmelspectrogram( - x, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, ) - - -class Stft2LogMelSpectrogram(): - def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - self.fmin = fmin - self.fmax = fmax - self.eps = eps - - def __repr__(self): - return ("{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "fmin={fmin}, fmax={fmax}, eps={eps}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, )) - - def __call__(self, x): - return stft2logmelspectrogram( - x, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - fmin=self.fmin, - fmax=self.fmax, ) - - -class Stft(): - def __init__( - self, - n_fft, - n_shift, - win_length=None, - window="hann", - center=True, - pad_mode="reflect", ): - self.n_fft = n_fft - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.center = center - self.pad_mode = pad_mode - - def __repr__(self): - return ("{name}(n_fft={n_fft}, n_shift={n_shift}, " - "win_length={win_length}, window={window}," - "center={center}, pad_mode={pad_mode})".format( - name=self.__class__.__name__, - n_fft=self.n_fft, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode=self.pad_mode, )) - - def __call__(self, x): - return stft( - x, - self.n_fft, - self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode=self.pad_mode, ) - - -class IStft(): - def __init__(self, n_shift, win_length=None, window="hann", center=True): - self.n_shift = n_shift - self.win_length = win_length - self.window = window - self.center = center - - def __repr__(self): - return ("{name}(n_shift={n_shift}, " - "win_length={win_length}, window={window}," - "center={center})".format( - name=self.__class__.__name__, - n_shift=self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, )) - - def __call__(self, x): - return istft( - x, - self.n_shift, - win_length=self.win_length, - window=self.window, - center=self.center, ) - - -class LogMelSpectrogramKaldi(): - def __init__( - self, - fs=16000, - n_mels=80, - n_shift=160, # unit:sample, 10ms - win_length=400, # unit:sample, 25ms - energy_floor=0.0, - dither=0.1): - """ - The Kaldi implementation of LogMelSpectrogram - Args: - fs (int): sample rate of the audio - n_mels (int): number of mel filter banks - n_shift (int): number of points in a frame shift - win_length (int): number of points in a frame windows - energy_floor (float): Floor on energy in Spectrogram computation (absolute) - dither (float): Dithering constant - - Returns: - LogMelSpectrogramKaldi - """ - - self.fs = fs - self.n_mels = n_mels - num_point_ms = fs / 1000 - self.n_frame_length = win_length / num_point_ms - self.n_frame_shift = n_shift / num_point_ms - self.energy_floor = energy_floor - self.dither = dither - - def __repr__(self): - return ( - "{name}(fs={fs}, n_mels={n_mels}, " - "n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, " - "dither={dither}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_frame_shift=self.n_frame_shift, - n_frame_length=self.n_frame_length, - dither=self.dither, )) - - def __call__(self, x, train): - """ - Args: - x (np.ndarray): shape (Ti,) - train (bool): True, train mode. 
- - Raises: - ValueError: not support (Ti, C) - - Returns: - np.ndarray: (T, D) - """ - dither = self.dither if train else 0.0 - if x.ndim != 1: - raise ValueError("Not support x: [Time, Channel]") - waveform = paddle.to_tensor(np.expand_dims(x, 0), dtype=paddle.float32) - mat = kaldi.fbank( - waveform, - n_mels=self.n_mels, - frame_length=self.n_frame_length, - frame_shift=self.n_frame_shift, - dither=dither, - energy_floor=self.energy_floor, - sr=self.fs) - mat = np.squeeze(mat.numpy()) - return mat - - -class LogMelSpectrogramKaldi_decay(): - def __init__( - self, - fs=16000, - n_mels=80, - n_fft=512, # fft point - n_shift=160, # unit:sample, 10ms - win_length=400, # unit:sample, 25ms - window="povey", - fmin=20, - fmax=None, - eps=1e-10, - dither=1.0): - self.fs = fs - self.n_mels = n_mels - self.n_fft = n_fft - if n_shift > win_length: - raise ValueError("Stride size must not be greater than " - "window size.") - self.n_shift = n_shift / fs # unit: ms - self.win_length = win_length / fs # unit: ms - - self.window = window - self.fmin = fmin - if fmax is None: - fmax_ = fmax if fmax else self.fs / 2 - elif fmax > int(self.fs / 2): - raise ValueError("fmax must not be greater than half of " - "sample rate.") - self.fmax = fmax_ - - self.eps = eps - self.remove_dc_offset = True - self.preemph = 0.97 - self.dither = dither # only work in train mode - - def __repr__(self): - return ( - "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " - "n_shift={n_shift}, win_length={win_length}, preemph={preemph}, window={window}, " - "fmin={fmin}, fmax={fmax}, eps={eps}, dither={dither}))".format( - name=self.__class__.__name__, - fs=self.fs, - n_mels=self.n_mels, - n_fft=self.n_fft, - n_shift=self.n_shift, - preemph=self.preemph, - win_length=self.win_length, - window=self.window, - fmin=self.fmin, - fmax=self.fmax, - eps=self.eps, - dither=self.dither, )) - - def __call__(self, x, train): - """ - - Args: - x (np.ndarray): shape (Ti,) - train (bool): True, train mode. - - Raises: - ValueError: not support (Ti, C) - - Returns: - np.ndarray: (T, D) - """ - dither = self.dither if train else 0.0 - if x.ndim != 1: - raise ValueError("Not support x: [Time, Channel]") - - if x.dtype in np.sctypes['float']: - # PCM32 -> PCM16 - bits = np.iinfo(np.int16).bits - x = x * 2**(bits - 1) - - # logfbank need PCM16 input - y = logfbank( - signal=x, - samplerate=self.fs, - winlen=self.win_length, # unit ms - winstep=self.n_shift, # unit ms - nfilt=self.n_mels, - nfft=self.n_fft, - lowfreq=self.fmin, - highfreq=self.fmax, - dither=dither, - remove_dc_offset=self.remove_dc_offset, - preemph=self.preemph, - wintype=self.window) - return y diff --git a/paddlespeech/s2t/transform/transform_interface.py b/paddlespeech/s2t/transform/transform_interface.py deleted file mode 100644 index 8bc62420..00000000 --- a/paddlespeech/s2t/transform/transform_interface.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# Modified from espnet(https://github.com/espnet/espnet) - - -class TransformInterface: - """Transform Interface""" - - def __call__(self, x): - raise NotImplementedError("__call__ method is not implemented") - - @classmethod - def add_arguments(cls, parser): - return parser - - def __repr__(self): - return self.__class__.__name__ + "()" - - -class Identity(TransformInterface): - """Identity Function""" - - def __call__(self, x): - return x diff --git a/paddlespeech/s2t/transform/transformation.py b/paddlespeech/s2t/transform/transformation.py deleted file mode 100644 index 3b433cb0..00000000 --- a/paddlespeech/s2t/transform/transformation.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -"""Transformation module.""" -import copy -import io -import logging -from collections import OrderedDict -from collections.abc import Sequence -from inspect import signature - -import yaml - -from paddlespeech.s2t.utils.dynamic_import import dynamic_import - -import_alias = dict( - identity="paddlespeech.s2t.transform.transform_interface:Identity", - time_warp="paddlespeech.s2t.transform.spec_augment:TimeWarp", - time_mask="paddlespeech.s2t.transform.spec_augment:TimeMask", - freq_mask="paddlespeech.s2t.transform.spec_augment:FreqMask", - spec_augment="paddlespeech.s2t.transform.spec_augment:SpecAugment", - speed_perturbation="paddlespeech.s2t.transform.perturb:SpeedPerturbation", - speed_perturbation_sox="paddlespeech.s2t.transform.perturb:SpeedPerturbationSox", - volume_perturbation="paddlespeech.s2t.transform.perturb:VolumePerturbation", - noise_injection="paddlespeech.s2t.transform.perturb:NoiseInjection", - bandpass_perturbation="paddlespeech.s2t.transform.perturb:BandpassPerturbation", - rir_convolve="paddlespeech.s2t.transform.perturb:RIRConvolve", - delta="paddlespeech.s2t.transform.add_deltas:AddDeltas", - cmvn="paddlespeech.s2t.transform.cmvn:CMVN", - utterance_cmvn="paddlespeech.s2t.transform.cmvn:UtteranceCMVN", - fbank="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogram", - spectrogram="paddlespeech.s2t.transform.spectrogram:Spectrogram", - stft="paddlespeech.s2t.transform.spectrogram:Stft", - istft="paddlespeech.s2t.transform.spectrogram:IStft", - stft2fbank="paddlespeech.s2t.transform.spectrogram:Stft2LogMelSpectrogram", - wpe="paddlespeech.s2t.transform.wpe:WPE", - channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector", - fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi", - cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN") - - -class Transformation(): - """Apply some functions to the mini-batch - - Examples: - >>> kwargs = {"process": [{"type": "fbank", - ... "n_mels": 80, - ... "fs": 16000}, - ... {"type": "cmvn", - ... "stats": "data/train/cmvn.ark", - ... "norm_vars": True}, - ... 
{"type": "delta", "window": 2, "order": 2}]} - >>> transform = Transformation(kwargs) - >>> bs = 10 - >>> xs = [np.random.randn(100, 80).astype(np.float32) - ... for _ in range(bs)] - >>> xs = transform(xs) - """ - - def __init__(self, conffile=None): - if conffile is not None: - if isinstance(conffile, dict): - self.conf = copy.deepcopy(conffile) - else: - with io.open(conffile, encoding="utf-8") as f: - self.conf = yaml.safe_load(f) - assert isinstance(self.conf, dict), type(self.conf) - else: - self.conf = {"mode": "sequential", "process": []} - - self.functions = OrderedDict() - if self.conf.get("mode", "sequential") == "sequential": - for idx, process in enumerate(self.conf["process"]): - assert isinstance(process, dict), type(process) - opts = dict(process) - process_type = opts.pop("type") - class_obj = dynamic_import(process_type, import_alias) - # TODO(karita): assert issubclass(class_obj, TransformInterface) - try: - self.functions[idx] = class_obj(**opts) - except TypeError: - try: - signa = signature(class_obj) - except ValueError: - # Some function, e.g. built-in function, are failed - pass - else: - logging.error("Expected signature: {}({})".format( - class_obj.__name__, signa)) - raise - else: - raise NotImplementedError( - "Not supporting mode={}".format(self.conf["mode"])) - - def __repr__(self): - rep = "\n" + "\n".join(" {}: {}".format(k, v) - for k, v in self.functions.items()) - return "{}({})".format(self.__class__.__name__, rep) - - def __call__(self, xs, uttid_list=None, **kwargs): - """Return new mini-batch - - :param Union[Sequence[np.ndarray], np.ndarray] xs: - :param Union[Sequence[str], str] uttid_list: - :return: batch: - :rtype: List[np.ndarray] - """ - if not isinstance(xs, Sequence): - is_batch = False - xs = [xs] - else: - is_batch = True - - if isinstance(uttid_list, str): - uttid_list = [uttid_list for _ in range(len(xs))] - - if self.conf.get("mode", "sequential") == "sequential": - for idx in range(len(self.conf["process"])): - func = self.functions[idx] - # TODO(karita): use TrainingTrans and UttTrans to check __call__ args - # Derive only the args which the func has - try: - param = signature(func).parameters - except ValueError: - # Some function, e.g. built-in function, are failed - param = {} - _kwargs = {k: v for k, v in kwargs.items() if k in param} - try: - if uttid_list is not None and "uttid" in param: - xs = [ - func(x, u, **_kwargs) - for x, u in zip(xs, uttid_list) - ] - else: - xs = [func(x, **_kwargs) for x in xs] - except Exception: - logging.fatal("Catch a exception from {}th func: {}".format( - idx, func)) - raise - else: - raise NotImplementedError( - "Not supporting mode={}".format(self.conf["mode"])) - - if is_batch: - return xs - else: - return xs[0] diff --git a/paddlespeech/s2t/transform/wpe.py b/paddlespeech/s2t/transform/wpe.py deleted file mode 100644 index 777379d0..00000000 --- a/paddlespeech/s2t/transform/wpe.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from espnet(https://github.com/espnet/espnet) -from nara_wpe.wpe import wpe - - -class WPE(object): - def __init__(self, - taps=10, - delay=3, - iterations=3, - psd_context=0, - statistics_mode="full"): - self.taps = taps - self.delay = delay - self.iterations = iterations - self.psd_context = psd_context - self.statistics_mode = statistics_mode - - def __repr__(self): - return ("{name}(taps={taps}, delay={delay}" - "iterations={iterations}, psd_context={psd_context}, " - "statistics_mode={statistics_mode})".format( - name=self.__class__.__name__, - taps=self.taps, - delay=self.delay, - iterations=self.iterations, - psd_context=self.psd_context, - statistics_mode=self.statistics_mode, )) - - def __call__(self, xs): - """Return enhanced - - :param np.ndarray xs: (Time, Channel, Frequency) - :return: enhanced_xs - :rtype: np.ndarray - - """ - # nara_wpe.wpe: (F, C, T) - xs = wpe( - xs.transpose((2, 1, 0)), - taps=self.taps, - delay=self.delay, - iterations=self.iterations, - psd_context=self.psd_context, - statistics_mode=self.statistics_mode, ) - return xs.transpose(2, 1, 0)
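
Note: for anyone who still depends on the removed transform.wpe wrapper, the sketch below shows roughly how it was invoked before this deletion. The import path is the one removed above, and the multi-channel STFT input is synthetic, so treat this as an illustrative reference rather than part of the patch.

    import numpy as np
    from paddlespeech.s2t.transform.wpe import WPE  # module deleted by this patch

    # synthetic multi-channel STFT, laid out as (Time, Channel, Frequency)
    xs = np.random.randn(300, 4, 257) + 1j * np.random.randn(300, 4, 257)

    # the wrapper transposes to (Frequency, Channel, Time) for nara_wpe.wpe
    # internally and transposes the result back
    enhance = WPE(taps=10, delay=3, iterations=3)
    enhanced = enhance(xs)  # dereverberated, same (Time, Channel, Frequency) layout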