diff --git a/data_utils/data.py b/data_utils/data.py deleted file mode 100644 index 30819f578..000000000 --- a/data_utils/data.py +++ /dev/null @@ -1,373 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contains data generator for orgnaizing various audio data preprocessing -pipeline and offering data reader interface of PaddlePaddle requirements. -""" - -import random -import tarfile -import multiprocessing -import numpy as np -import paddle.fluid as fluid -from threading import local -from data_utils.utility import read_manifest -from data_utils.augmentor.augmentation import AugmentationPipeline -from data_utils.featurizer.speech_featurizer import SpeechFeaturizer -from data_utils.speech import SpeechSegment -from data_utils.normalizer import FeatureNormalizer - -__all__ = ['DataGenerator'] - - -class DataGenerator(): - """ - DataGenerator provides basic audio data preprocessing pipeline, and offers - data reader interfaces of PaddlePaddle requirements. - - :param vocab_filepath: Vocabulary filepath for indexing tokenized - transcripts. - :type vocab_filepath: str - :param mean_std_filepath: File containing the pre-computed mean and stddev. - :type mean_std_filepath: None|str - :param augmentation_config: Augmentation configuration in json string. - Details see AugmentationPipeline.__doc__. - :type augmentation_config: str - :param max_duration: Audio with duration (in seconds) greater than - this will be discarded. - :type max_duration: float - :param min_duration: Audio with duration (in seconds) smaller than - this will be discarded. - :type min_duration: float - :param stride_ms: Striding size (in milliseconds) for generating frames. - :type stride_ms: float - :param window_ms: Window size (in milliseconds) for generating frames. - :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins - corresponding to frequencies between [0, max_freq] are - returned. - :types max_freq: None|float - :param specgram_type: Specgram feature type. Options: 'linear'. - :type specgram_type: str - :param use_dB_normalization: Whether to normalize the audio to -20 dB - before extracting the features. - :type use_dB_normalization: bool - :param random_seed: Random seed. - :type random_seed: int - :param keep_transcription_text: If set to True, transcription text will - be passed forward directly without - converting to index sequence. - :type keep_transcription_text: bool - :param place: The place to run the program. - :type place: CPUPlace or CUDAPlace - :param is_training: If set to True, generate text data for training, - otherwise, generate text data for infer. 
- :type is_training: bool - """ - - def __init__(self, - vocab_filepath, - mean_std_filepath, - augmentation_config='{}', - max_duration=float('inf'), - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - specgram_type='linear', - use_dB_normalization=True, - random_seed=0, - keep_transcription_text=False, - place=fluid.CPUPlace(), - is_training=True): - self._max_duration = max_duration - self._min_duration = min_duration - self._normalizer = FeatureNormalizer(mean_std_filepath) - self._augmentation_pipeline = AugmentationPipeline( - augmentation_config=augmentation_config, random_seed=random_seed) - self._speech_featurizer = SpeechFeaturizer( - vocab_filepath=vocab_filepath, - specgram_type=specgram_type, - stride_ms=stride_ms, - window_ms=window_ms, - max_freq=max_freq, - use_dB_normalization=use_dB_normalization) - self._rng = random.Random(random_seed) - self._keep_transcription_text = keep_transcription_text - self._epoch = 0 - self._is_training = is_training - # for caching tar files info - self._local_data = local() - self._local_data.tar2info = {} - self._local_data.tar2object = {} - self._place = place - - def process_utterance(self, audio_file, transcript): - """Load, augment, featurize and normalize for speech data. - - :param audio_file: Filepath or file object of audio file. - :type audio_file: str | file - :param transcript: Transcription text. - :type transcript: str - :return: Tuple of audio feature tensor and data of transcription part, - where transcription part could be token ids or text. - :rtype: tuple of (2darray, list) - """ - if isinstance(audio_file, str) and audio_file.startswith('tar:'): - speech_segment = SpeechSegment.from_file( - self._subfile_from_tar(audio_file), transcript) - else: - speech_segment = SpeechSegment.from_file(audio_file, transcript) - self._augmentation_pipeline.transform_audio(speech_segment) - specgram, transcript_part = self._speech_featurizer.featurize( - speech_segment, self._keep_transcription_text) - specgram = self._normalizer.apply(specgram) - return specgram, transcript_part - - def batch_reader_creator(self, - manifest_path, - batch_size, - padding_to=-1, - flatten=False, - sortagrad=False, - shuffle_method="batch_shuffle"): - """ - Batch data reader creator for audio data. Return a callable generator - function to produce batches of data. - - Audio features within one batch will be padded with zeros to have the - same shape, or a user-defined shape. - - :param manifest_path: Filepath of manifest for audio files. - :type manifest_path: str - :param batch_size: Number of instances in a batch. - :type batch_size: int - :param padding_to: If set -1, the maximun shape in the batch - will be used as the target shape for padding. - Otherwise, `padding_to` will be the target shape. - :type padding_to: int - :param flatten: If set True, audio features will be flatten to 1darray. - :type flatten: bool - :param sortagrad: If set True, sort the instances by audio duration - in the first epoch for speed up training. - :type sortagrad: bool - :param shuffle_method: Shuffle method. Options: - '' or None: no shuffle. - 'instance_shuffle': instance-wise shuffle. - 'batch_shuffle': similarly-sized instances are - put into batches, and then - batch-wise shuffle the batches. - For more details, please see - ``_batch_shuffle.__doc__``. - 'batch_shuffle_clipped': 'batch_shuffle' with - head shift and tail - clipping. For more - details, please see - ``_batch_shuffle``. 
- If sortagrad is True, shuffle is disabled - for the first epoch. - :type shuffle_method: None|str - :return: Batch reader function, producing batches of data when called. - :rtype: callable - """ - - def batch_reader(): - # read manifest - manifest = read_manifest( - manifest_path=manifest_path, - max_duration=self._max_duration, - min_duration=self._min_duration) - - # sort (by duration) or batch-wise shuffle the manifest - if self._epoch == 0 and sortagrad: - manifest.sort(key=lambda x: x["duration"]) - - else: - if shuffle_method == "batch_shuffle": - manifest = self._batch_shuffle( - manifest, batch_size, clipped=False) - elif shuffle_method == "batch_shuffle_clipped": - manifest = self._batch_shuffle( - manifest, batch_size, clipped=True) - elif shuffle_method == "instance_shuffle": - self._rng.shuffle(manifest) - elif shuffle_method == None: - pass - else: - raise ValueError("Unknown shuffle method %s." % - shuffle_method) - # prepare batches - batch = [] - instance_reader = self._instance_reader_creator(manifest) - - for instance in instance_reader(): - batch.append(instance) - if len(batch) == batch_size: - yield self._padding_batch(batch, padding_to, flatten) - batch = [] - if len(batch) >= 1: - yield self._padding_batch(batch, padding_to, flatten) - self._epoch += 1 - - return batch_reader - - @property - def feeding(self): - """Returns data reader's feeding dict. - - :return: Data feeding dict. - :rtype: dict - """ - feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1} - return feeding_dict - - @property - def vocab_size(self): - """Return the vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return self._speech_featurizer.vocab_size - - @property - def vocab_list(self): - """Return the vocabulary in list. - - :return: Vocabulary in list. - :rtype: list - """ - return self._speech_featurizer.vocab_list - - def _parse_tar(self, file): - """Parse a tar file to get a tarfile object - and a map containing tarinfoes - """ - result = {} - f = tarfile.open(file) - for tarinfo in f.getmembers(): - result[tarinfo.name] = tarinfo - return f, result - - def _subfile_from_tar(self, file): - """Get subfile object from tar. - - It will return a subfile object from tar file - and cached tar file info for next reading request. - """ - tarpath, filename = file.split(':', 1)[1].split('#', 1) - if 'tar2info' not in self._local_data.__dict__: - self._local_data.tar2info = {} - if 'tar2object' not in self._local_data.__dict__: - self._local_data.tar2object = {} - if tarpath not in self._local_data.tar2info: - object, infoes = self._parse_tar(tarpath) - self._local_data.tar2info[tarpath] = infoes - self._local_data.tar2object[tarpath] = object - return self._local_data.tar2object[tarpath].extractfile( - self._local_data.tar2info[tarpath][filename]) - - def _instance_reader_creator(self, manifest): - """ - Instance reader creator. Create a callable function to produce - instances of data. - - Instance: a tuple of ndarray of audio spectrogram and a list of - token indices for transcript. - """ - - def reader(): - for instance in manifest: - inst = self.process_utterance(instance["audio_filepath"], - instance["text"]) - yield inst - - return reader - - def _padding_batch(self, batch, padding_to=-1, flatten=False): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. - - If ``padding_to`` is -1, the maximun shape in the batch will be used - as the target shape for padding. 
Otherwise, `padding_to` will be the - target shape (only refers to the second axis). - - If `flatten` is True, features will be flatten to 1darray. - """ - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, text in batch]) - if padding_to != -1: - if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = padding_to - max_text_length = max([len(text) for audio, text in batch]) - # padding - padded_audios = [] - audio_lens = [] - texts, text_lens = [], [] - for audio, text in batch: - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - if flatten: - padded_audio = padded_audio.flatten() - padded_audios.append(padded_audio) - audio_lens.append(audio.shape[1]) - if self._is_training: - padded_text = np.zeros([max_text_length]) - padded_text[:len(text)] = text - texts.append(padded_text) - else: - texts.append(text) - text_lens.append(len(text)) - - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - if self._is_training: - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens - - def _batch_shuffle(self, manifest, batch_size, clipped=False): - """Put similarly-sized instances into minibatches for better efficiency - and make a batch-wise shuffle. - - 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly shift `k` instances in order to create different batches - for different epochs. Create minibatches. - 4. Shuffle the minibatches. - - :param manifest: Manifest contents. List of dict. - :type manifest: list - :param batch_size: Batch size. This size is also used for generate - a random number for batch shuffle. - :type batch_size: int - :param clipped: Whether to clip the heading (small shift) and trailing - (incomplete batch) instances. - :type clipped: bool - :return: Batch shuffled mainifest. 
- :rtype: list - """ - manifest.sort(key=lambda x: x["duration"]) - shift_len = self._rng.randint(0, batch_size - 1) - batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size)) - self._rng.shuffle(batch_manifest) - batch_manifest = [item for batch in batch_manifest for item in batch] - if not clipped: - res_len = len(manifest) - shift_len - len(batch_manifest) - batch_manifest.extend(manifest[-res_len:]) - batch_manifest.extend(manifest[0:shift_len]) - return batch_manifest diff --git a/data_utils/dataset.py b/data_utils/dataset.py index e9b581d0b..2de8f87a2 100644 --- a/data_utils/dataset.py +++ b/data_utils/dataset.py @@ -226,13 +226,15 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler): :rtype: list """ rng = np.random.RandomState(self.epoch) - shift_len = rng.randint(0, batch_size - 1) + # must shift at leat by one + shift_len = rng.randint(1, batch_size - 1) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert (clipped == False) if not clipped: res_len = len(indices) - shift_len - len(batch_indices) + # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) batch_indices.extend(indices[-res_len:]) batch_indices.extend(indices[0:shift_len]) return batch_indices @@ -256,7 +258,9 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler): else: raise ValueError("Unknown shuffle method %s." % self._shuffle_method) - assert len(indices) == self.total_size + assert len( + indices + ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" self.epoch += 1 # subsample @@ -362,13 +366,15 @@ class DeepSpeech2BatchSampler(BatchSampler): :rtype: list """ rng = np.random.RandomState(self.epoch) - shift_len = rng.randint(0, batch_size - 1) + # must shift at leat by one + shift_len = rng.randint(1, batch_size - 1) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert (clipped == False) if not clipped: res_len = len(indices) - shift_len - len(batch_indices) + # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) batch_indices.extend(indices[-res_len:]) batch_indices.extend(indices[0:shift_len]) return batch_indices @@ -392,7 +398,9 @@ class DeepSpeech2BatchSampler(BatchSampler): else: raise ValueError("Unknown shuffle method %s." 
% self._shuffle_method) - assert len(indices) == self.total_size + assert len( + indices + ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}" self.epoch += 1 # subsample @@ -516,105 +524,105 @@ class SpeechCollator(): return padded_audios, texts, audio_lens, text_lens -def create_dataloader(manifest_path, - vocab_filepath, - mean_std_filepath, - augmentation_config='{}', - max_duration=float('inf'), - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - specgram_type='linear', - use_dB_normalization=True, - random_seed=0, - keep_transcription_text=False, - is_training=False, - batch_size=1, - num_workers=0, - sortagrad=False, - shuffle_method=None, - dist=False): - - dataset = DeepSpeech2Dataset( - manifest_path, - vocab_filepath, - mean_std_filepath, - augmentation_config=augmentation_config, - max_duration=max_duration, - min_duration=min_duration, - stride_ms=stride_ms, - window_ms=window_ms, - max_freq=max_freq, - specgram_type=specgram_type, - use_dB_normalization=use_dB_normalization, - random_seed=random_seed, - keep_transcription_text=keep_transcription_text) - - if dist: - batch_sampler = DeepSpeech2DistributedBatchSampler( - dataset, - batch_size, - num_replicas=None, - rank=None, - shuffle=is_training, - drop_last=is_training, - sortagrad=is_training, - shuffle_method=shuffle_method) - else: - batch_sampler = DeepSpeech2BatchSampler( - dataset, - shuffle=is_training, - batch_size=batch_size, - drop_last=is_training, - sortagrad=is_training, - shuffle_method=shuffle_method) - - def padding_batch(batch, padding_to=-1, flatten=False, is_training=True): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. - - If ``padding_to`` is -1, the maximun shape in the batch will be used - as the target shape for padding. Otherwise, `padding_to` will be the - target shape (only refers to the second axis). - - If `flatten` is True, features will be flatten to 1darray. 
- """ - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, text in batch]) - if padding_to != -1: - if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = padding_to - max_text_length = max([len(text) for audio, text in batch]) - # padding - padded_audios = [] - audio_lens = [] - texts, text_lens = [], [] - for audio, text in batch: - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - if flatten: - padded_audio = padded_audio.flatten() - padded_audios.append(padded_audio) - audio_lens.append(audio.shape[1]) - padded_text = np.zeros([max_text_length]) - padded_text[:len(text)] = text - texts.append(padded_text) - text_lens.append(len(text)) - - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens - - loader = DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=partial(padding_batch, is_training=is_training), - num_workers=num_workers, ) - return loader +# def create_dataloader(manifest_path, +# vocab_filepath, +# mean_std_filepath, +# augmentation_config='{}', +# max_duration=float('inf'), +# min_duration=0.0, +# stride_ms=10.0, +# window_ms=20.0, +# max_freq=None, +# specgram_type='linear', +# use_dB_normalization=True, +# random_seed=0, +# keep_transcription_text=False, +# is_training=False, +# batch_size=1, +# num_workers=0, +# sortagrad=False, +# shuffle_method=None, +# dist=False): + +# dataset = DeepSpeech2Dataset( +# manifest_path, +# vocab_filepath, +# mean_std_filepath, +# augmentation_config=augmentation_config, +# max_duration=max_duration, +# min_duration=min_duration, +# stride_ms=stride_ms, +# window_ms=window_ms, +# max_freq=max_freq, +# specgram_type=specgram_type, +# use_dB_normalization=use_dB_normalization, +# random_seed=random_seed, +# keep_transcription_text=keep_transcription_text) + +# if dist: +# batch_sampler = DeepSpeech2DistributedBatchSampler( +# dataset, +# batch_size, +# num_replicas=None, +# rank=None, +# shuffle=is_training, +# drop_last=is_training, +# sortagrad=is_training, +# shuffle_method=shuffle_method) +# else: +# batch_sampler = DeepSpeech2BatchSampler( +# dataset, +# shuffle=is_training, +# batch_size=batch_size, +# drop_last=is_training, +# sortagrad=is_training, +# shuffle_method=shuffle_method) + +# def padding_batch(batch, padding_to=-1, flatten=False, is_training=True): +# """ +# Padding audio features with zeros to make them have the same shape (or +# a user-defined shape) within one bach. + +# If ``padding_to`` is -1, the maximun shape in the batch will be used +# as the target shape for padding. Otherwise, `padding_to` will be the +# target shape (only refers to the second axis). + +# If `flatten` is True, features will be flatten to 1darray. 
+# """ +# new_batch = [] +# # get target shape +# max_length = max([audio.shape[1] for audio, text in batch]) +# if padding_to != -1: +# if padding_to < max_length: +# raise ValueError("If padding_to is not -1, it should be larger " +# "than any instance's shape in the batch") +# max_length = padding_to +# max_text_length = max([len(text) for audio, text in batch]) +# # padding +# padded_audios = [] +# audio_lens = [] +# texts, text_lens = [], [] +# for audio, text in batch: +# padded_audio = np.zeros([audio.shape[0], max_length]) +# padded_audio[:, :audio.shape[1]] = audio +# if flatten: +# padded_audio = padded_audio.flatten() +# padded_audios.append(padded_audio) +# audio_lens.append(audio.shape[1]) +# padded_text = np.zeros([max_text_length]) +# padded_text[:len(text)] = text +# texts.append(padded_text) +# text_lens.append(len(text)) + +# padded_audios = np.array(padded_audios).astype('float32') +# audio_lens = np.array(audio_lens).astype('int64') +# texts = np.array(texts).astype('int32') +# text_lens = np.array(text_lens).astype('int64') +# return padded_audios, texts, audio_lens, text_lens + +# loader = DataLoader( +# dataset, +# batch_sampler=batch_sampler, +# collate_fn=partial(padding_batch, is_training=is_training), +# num_workers=num_workers, ) +# return loader diff --git a/decoders/swig/setup.py b/decoders/swig/setup.py index 0fcb24b50..f6dc048da 100644 --- a/decoders/swig/setup.py +++ b/decoders/swig/setup.py @@ -81,9 +81,8 @@ FILES = glob.glob('kenlm/util/*.cc') \ FILES += glob.glob('openfst-1.6.3/src/lib/*.cc') FILES = [ - fn for fn in FILES - if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith( - 'unittest.cc')) + fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc') + or fn.endswith('unittest.cc')) ] LIBS = ['stdc++'] diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh index b197a9fd1..76280b04f 100644 --- a/examples/tiny/local/run_train.sh +++ b/examples/tiny/local/run_train.sh @@ -15,10 +15,10 @@ export FLAGS_sync_nccl_allreduce=0 #--shuffle_method="batch_shuffle_clipped" \ #CUDA_VISIBLE_DEVICES=0,1,2,3 \ -CUDA_VISIBLE_DEVICES=1,2,3 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ python3 -u ${MAIN_ROOT}/train.py \ --device 'gpu' \ ---nproc 1 \ +--nproc 4 \ --config conf/deepspeech2.yaml \ --output ckpt diff --git a/model_utils/config.py b/model_utils/config.py index ead2dbd5f..f4b876045 100644 --- a/model_utils/config.py +++ b/model_utils/config.py @@ -46,8 +46,8 @@ _C.model = CN( num_conv_layers=2, #Number of stacking convolution layers. num_rnn_layers=3, #Number of stacking RNN layers. rnn_layer_size=1024, #RNN layer size (number of RNN cells). - use_gru=True, #Use gru if set True. Use simple rnn if set False. - share_rnn_weights=False #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + use_gru=False, #Use gru if set True. Use simple rnn if set False. + share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. )) _C.training = CN( @@ -62,6 +62,20 @@ _C.training = CN( n_epoch=50, # train epochs )) +_C.decoding = CN( + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. 
+ decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) + def get_cfg_defaults(): """Get a yacs CfgNode object with default values for my_project.""" diff --git a/model_utils/model.py b/model_utils/model.py index a8a245036..a14ebde2c 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -43,7 +43,138 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch class DeepSpeech2Trainer(Trainer): def __init__(self, config, args): super().__init__(config, args) - self._ext_scorer = None + + def compute_losses(self, inputs, outputs): + _, texts, _, texts_len = inputs + logits, _, logits_len = outputs + loss = self.criterion(logits, texts, logits_len, texts_len) + return loss + + def read_batch(self): + """Read a batch from the train_loader. + Returns + ------- + List[Tensor] + A batch. + """ + try: + batch = next(self.iterator) + except StopIteration as e: + raise e + return batch + + def train_batch(self): + start = time.time() + batch = self.read_batch() + data_loader_time = time.time() - start + + self.optimizer.clear_grad() + self.model.train() + audio, text, audio_len, text_len = batch + outputs = self.model(audio, text, audio_len, text_len) + loss = self.compute_losses(batch, outputs) + loss.backward() + self.optimizer.step() + iteration_time = time.time() - start + + losses_np = {'train_loss': float(loss)} + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + + msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, + iteration_time) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_np.items()) + self.logger.info(msg) + + if dist.get_rank() == 0: + for k, v in losses_np.items(): + self.visualizer.add_scalar("train/{}".format(k), v, + self.iteration) + + def train(self): + """The training process. + + It includes forward/backward/update and periodical validation and + saving. 
+ """ + self.new_epoch() + while self.epoch <= self.config.training.n_epoch: + try: + self.iteration += 1 + self.train_batch() + + if self.iteration % self.config.training.valid_interval == 0: + self.valid() + + if self.iteration % self.config.training.save_interval == 0: + self.save() + except StopIteration: + self.iteration -= 1 #epoch end, iteration ahead 1 + self.valid() + self.save() + self.new_epoch() + + def compute_metrics(self, inputs, outputs): + pass + + @mp_tools.rank_zero_only + @paddle.no_grad() + def valid(self): + self.model.eval() + valid_losses = defaultdict(list) + for i, batch in enumerate(self.valid_loader): + audio, text, audio_len, text_len = batch + outputs = self.model(audio, text, audio_len, text_len) + loss = self.compute_losses(batch, outputs) + metrics = self.compute_metrics(batch, outputs) + + valid_losses['val_loss'].append(float(loss)) + + # write visual log + valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} + + # logging + msg = f"Valid: Rank: {dist.get_rank()}, " + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in valid_losses.items()) + self.logger.info(msg) + + for k, v in valid_losses.items(): + self.visualizer.add_scalar("valid/{}".format(k), v, self.iteration) + + def setup_model(self): + config = self.config + model = DeepSpeech2( + feat_size=self.train_loader.dataset.feature_size, + dict_size=self.train_loader.dataset.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + share_rnn_weights=config.model.share_rnn_weights) + + if self.parallel: + model = paddle.DataParallel(model) + + grad_clip = paddle.nn.ClipGradByGlobalNorm( + config.training.global_grad_clip) + + optimizer = paddle.optimizer.Adam( + learning_rate=config.training.lr, + parameters=model.parameters(), + weight_decay=paddle.regularizer.L2Decay( + config.training.weight_decay), + grad_clip=grad_clip) + + criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size) + + self.model = model + self.optimizer = optimizer + self.criterion = criterion + self.logger.info("Setup model/optimizer/criterion!") def setup_dataloader(self): config = self.config @@ -119,206 +250,91 @@ class DeepSpeech2Trainer(Trainer): collate_fn=collate_fn) self.logger.info("Setup train/valid Dataloader!") - def setup_model(self): - config = self.config - model = DeepSpeech2( - feat_size=self.train_loader.dataset.feature_size, - dict_size=self.train_loader.dataset.vocab_size, - num_conv_layers=config.model.num_conv_layers, - num_rnn_layers=config.model.num_rnn_layers, - rnn_size=config.model.rnn_layer_size, - share_rnn_weights=config.model.share_rnn_weights) - if self.parallel: - model = paddle.DataParallel(model) - - grad_clip = paddle.nn.ClipGradByGlobalNorm( - config.training.global_grad_clip) - - optimizer = paddle.optimizer.Adam( - learning_rate=config.training.lr, - parameters=model.parameters(), - weight_decay=paddle.regularizer.L2Decay( - config.training.weight_decay), - grad_clip=grad_clip) - - criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size) - - self.model = model - self.optimizer = optimizer - self.criterion = criterion - self.logger.info("Setup model/optimizer/criterion!") +class DeepSpeech2Tester(Trainer): + def __init__(self, config, args): + super().__init__(config, args) def compute_losses(self, inputs, outputs): - del inputs - logits, texts, logits_len, texts_len = outputs + _, 
texts, _, texts_len = inputs + logits, _, logits_len = outputs loss = self.criterion(logits, texts, logits_len, texts_len) return loss - def train_batch(self): - start = time.time() - batch = self.read_batch() - data_loader_time = time.time() - start - - self.optimizer.clear_grad() - self.model.train() - audio, text, audio_len, text_len = batch - outputs = self.model(audio, text, audio_len, text_len) - loss = self.compute_losses(batch, outputs) - loss.backward() - self.optimizer.step() - iteration_time = time.time() - start - - losses_np = {'train_loss': float(loss)} - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - - msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time, - iteration_time) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in losses_np.items()) - self.logger.info(msg) - - if dist.get_rank() == 0: - for k, v in losses_np.items(): - self.visualizer.add_scalar("train/{}".format(k), v, - self.iteration) - def compute_metrics(self, inputs, outputs): + _, texts, _, texts_len = inputs + logits, _, logits_len = outputs pass @mp_tools.rank_zero_only @paddle.no_grad() - def valid(self): - valid_losses = defaultdict(list) - for i, batch in enumerate(self.valid_loader): + def test(self): + self.model.eval() + losses = defaultdict(list) + for i, batch in enumerate(self.test_loader): audio, text, audio_len, text_len = batch - outputs = self.model(audio, text, audio_len, text_len) + outputs = self.model.predict(audio, audio_len) loss = self.compute_losses(batch, outputs) metrics = self.compute_metrics(batch, outputs) - valid_losses['val_loss'].append(float(loss)) + losses['test_loss'].append(float(loss)) # write visual log - valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} + losses = {k: np.mean(v) for k, v in losses.items()} # logging - msg = "Valid: " + msg = "Test: " msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) - msg += ', '.join('{}: {:>.6f}'.format(k, v) - for k, v in valid_losses.items()) + msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items()) self.logger.info(msg) - for k, v in valid_losses.items(): - self.visualizer.add_scalar("valid/{}".foramt(k), v, self.iteration) - - def infer_batch_probs(self, infer_data): - """Infer the prob matrices for a batch of speech utterances. - :param infer_data: List of utterances to infer, with each utterance - consisting of a tuple of audio features and - transcription text (empty string). - :type infer_data: list - :return: List of 2-D probability matrix, and each consists of prob - vectors for one speech utterancce. - :rtype: List of matrix - """ - self.model.eval() - audio, text, audio_len, text_len = infer_data - _, probs = self.model.predict(audio, audio_len) - return probs - - def decode_batch_greedy(self, probs_split, vocab_list): - """Decode by best path for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :return: List of transcription texts. 
- :rtype: List of str - """ - results = [] - for i, probs in enumerate(probs_split): - output_transcription = ctc_greedy_decoder( - probs_seq=probs, vocabulary=vocab_list) - results.append(output_transcription) - print(results) - return results - - def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, - vocab_list): - """Initialize the external scorer. - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param language_model_path: Filepath for language model. If it is - empty, the external scorer will be set to - None, and the decoding method will be pure - beam search without scorer. - :type language_model_path: str|None - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - """ - if language_model_path != '': - self.logger.info("begin to initialize the external scorer " - "for decoding") - self._ext_scorer = Scorer(beam_alpha, beam_beta, - language_model_path, vocab_list) - lm_char_based = self._ext_scorer.is_character_based() - lm_max_order = self._ext_scorer.get_max_order() - lm_dict_size = self._ext_scorer.get_dict_size() - self.logger.info("language model: " - "is_character_based = %d," % lm_char_based + - " max_order = %d," % lm_max_order + - " dict_size = %d" % lm_dict_size) - self.logger.info("end initializing scorer") - else: - self._ext_scorer = None - self.logger.info("no language model provided, " - "decoding by pure beam search without scorer.") - - def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, - beam_size, cutoff_prob, cutoff_top_n, - vocab_list, num_processes): - """Decode by beam search for a batch of probs matrix input. - :param probs_split: List of 2-D probability matrix, and each consists - of prob vectors for one speech utterancce. - :param probs_split: List of matrix - :param beam_alpha: Parameter associated with language model. - :type beam_alpha: float - :param beam_beta: Parameter associated with word count. - :type beam_beta: float - :param beam_size: Width for Beam search. - :type beam_size: int - :param cutoff_prob: Cutoff probability in pruning, - default 1.0, no pruning. - :type cutoff_prob: float - :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n - characters with highest probs in vocabulary will be - used in beam search, default 40. - :type cutoff_top_n: int - :param vocab_list: List of tokens in the vocabulary, for decoding. - :type vocab_list: list - :param num_processes: Number of processes (CPU) for decoder. - :type num_processes: int - :return: List of transcription texts. 
- :rtype: List of str - """ - if self._ext_scorer != None: - self._ext_scorer.reset_params(beam_alpha, beam_beta) - # beam search decode - num_processes = min(num_processes, len(probs_split)) - beam_search_results = ctc_beam_search_decoder_batch( - probs_split=probs_split, - vocabulary=vocab_list, - beam_size=beam_size, - num_processes=num_processes, - ext_scoring_func=self._ext_scorer, - cutoff_prob=cutoff_prob, - cutoff_top_n=cutoff_top_n) - - results = [result[0][1] for result in beam_search_results] - return results + for k, v in losses.items(): + self.visualizer.add_scalar("test/{}".format(k), v, self.iteration) + + def setup_model(self): + config = self.config + model = DeepSpeech2( + feat_size=self.train_loader.dataset.feature_size, + dict_size=self.train_loader.dataset.vocab_size, + num_conv_layers=config.model.num_conv_layers, + num_rnn_layers=config.model.num_rnn_layers, + rnn_size=config.model.rnn_layer_size, + share_rnn_weights=config.model.share_rnn_weights) + + if self.parallel: + model = paddle.DataParallel(model) + + criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size) + + self.model = model + self.criterion = criterion + self.logger.info("Setup model/criterion!") + + def setup_dataloader(self): + config = self.config + test_dataset = DeepSpeech2Dataset( + config.data.test_manifest, + config.data.vocab_filepath, + config.data.mean_std_filepath, + augmentation_config="{}", + max_duration=config.data.max_duration, + min_duration=config.data.min_duration, + stride_ms=config.data.stride_ms, + window_ms=config.data.window_ms, + n_fft=config.data.n_fft, + max_freq=config.data.max_freq, + target_sample_rate=config.data.target_sample_rate, + specgram_type=config.data.specgram_type, + use_dB_normalization=config.data.use_dB_normalization, + target_dB=config.data.target_dB, + random_seed=config.data.random_seed, + keep_transcription_text=False) + + self.test_loader = DataLoader( + test_dataset, + batch_size=config.data.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn) + self.logger.info("Setup test Dataloader!") diff --git a/model_utils/network.py b/model_utils/network.py index 14257db6a..e756996d5 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -15,11 +15,17 @@ import math import collections import numpy as np +import logging + import paddle from paddle import nn from paddle.nn import functional as F from paddle.nn import initializer as I +from decoders.swig_wrapper import Scorer +from decoders.swig_wrapper import ctc_greedy_decoder +from decoders.swig_wrapper import ctc_beam_search_decoder_batch + __all__ = ['DeepSpeech2', 'DeepSpeech2Loss'] @@ -497,7 +503,10 @@ class DeepSpeech2(nn.Layer): share_rnn_weights=share_rnn_weights) self.fc = nn.Linear(rnn_size * 2, dict_size + 1) - def predict(self, audio, audio_len): + self.logger = logging.getLogger(__name__) + self._ext_scorer = None + + def infer(self, audio, audio_len): # [B, D, T] -> [B, C=1, D, T] audio = audio.unsqueeze(1) @@ -519,11 +528,6 @@ class DeepSpeech2(nn.Layer): return logits, probs, audio_len - @paddle.no_grad() - def infer(self, audio, audio_len): - _, probs, audio_len = self.predict(audio, audio_len) - return probs - def forward(self, audio, text, audio_len, text_len): """ audio: shape [B, D, T] @@ -531,8 +535,138 @@ class DeepSpeech2(nn.Layer): audio_len: shape [B] text_len: shape [B] """ - logits, _, audio_len = self.predict(audio, audio_len) - return logits, text, audio_len, text_len + return self.infer(audio, audio_len) + + @paddle.no_grad() + def 
predict(self, audio, audio_len): + """Model inference.""" + return self.infer(audio, audio_len) + + def _decode_batch_greedy(self, probs_split, vocab_list): + """Decode by best path for a batch of probs matrix input. + :param probs_split: List of 2-D probability matrix, and each consists + of prob vectors for one speech utterance. + :type probs_split: list + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :return: List of transcription texts. + :rtype: List of str + """ + results = [] + for i, probs in enumerate(probs_split): + output_transcription = ctc_greedy_decoder( + probs_seq=probs, vocabulary=vocab_list) + results.append(output_transcription) + return results + + def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path, + vocab_list): + """Initialize the external scorer. + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param language_model_path: Filepath for language model. If it is + empty, the external scorer will be set to + None, and the decoding method will be pure + beam search without scorer. + :type language_model_path: str|None + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + """ + # init once + if self._ext_scorer != None: + return + + if language_model_path != '': + self.logger.info("begin to initialize the external scorer " + "for decoding") + self._ext_scorer = Scorer(beam_alpha, beam_beta, + language_model_path, vocab_list) + lm_char_based = self._ext_scorer.is_character_based() + lm_max_order = self._ext_scorer.get_max_order() + lm_dict_size = self._ext_scorer.get_dict_size() + self.logger.info("language model: " + "is_character_based = %d," % lm_char_based + + " max_order = %d," % lm_max_order + + " dict_size = %d" % lm_dict_size) + self.logger.info("end initializing scorer") + else: + self._ext_scorer = None + self.logger.info("no language model provided, " + "decoding by pure beam search without scorer.") + + def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta, + beam_size, cutoff_prob, cutoff_top_n, + vocab_list, num_processes): + """Decode by beam search for a batch of probs matrix input. + :param probs_split: List of 2-D probability matrix, and each consists + of prob vectors for one speech utterance. + :type probs_split: list + :param beam_alpha: Parameter associated with language model. + :type beam_alpha: float + :param beam_beta: Parameter associated with word count. + :type beam_beta: float + :param beam_size: Width for beam search. + :type beam_size: int + :param cutoff_prob: Cutoff probability in pruning, + default 1.0, no pruning. + :type cutoff_prob: float + :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n + characters with highest probs in vocabulary will be + used in beam search, default 40. + :type cutoff_top_n: int + :param vocab_list: List of tokens in the vocabulary, for decoding. + :type vocab_list: list + :param num_processes: Number of processes (CPU) for decoder. + :type num_processes: int + :return: List of transcription texts.
+ :rtype: List of str + """ + if self._ext_scorer != None: + self._ext_scorer.reset_params(beam_alpha, beam_beta) + + # beam search decode + num_processes = min(num_processes, len(probs_split)) + beam_search_results = ctc_beam_search_decoder_batch( + probs_split=probs_split, + vocabulary=vocab_list, + beam_size=beam_size, + num_processes=num_processes, + ext_scoring_func=self._ext_scorer, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n) + + results = [result[0][1] for result in beam_search_results] + return results + + def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list, + decoding_method): + if decoding_method == "ctc_beam_search": + self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, + vocab_list) + + @paddle.no_grad() + def decode(self, audio, audio_len, vocab_list, decoding_method, + lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, + cutoff_top_n, num_processes): + _, probs, _ = self.predict(audio, audio_len) + if decoding_method == "ctc_greedy": + result_transcripts = self._decode_batch_greedy( + probs_split=probs, vocab_list=vocab_list) + elif decoding_method == "ctc_beam_search": + result_transcripts = self._decode_batch_beam_search( + probs_split=probs, + beam_alpha=beam_alpha, + beam_beta=beam_beta, + beam_size=beam_size, + cutoff_prob=cutoff_prob, + cutoff_top_n=cutoff_top_n, + vocab_list=vocab_list, + num_processes=num_processes) + else: + raise ValueError(f"Not supported: {decoding_method}") + return result_transcripts class DeepSpeech2Loss(nn.Layer): diff --git a/model_utils/network2_test.py b/model_utils/network_test.py similarity index 100% rename from model_utils/network2_test.py rename to model_utils/network_test.py diff --git a/model_utils/trainer.py b/model_utils/trainer.py deleted file mode 100644 index 90a2bfb85..000000000 --- a/model_utils/trainer.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -import logging -from pathlib import Path -import numpy as np -import paddle -from paddle import distributed as dist -from paddle.io import DataLoader, DistributedBatchSampler -from tensorboardX import SummaryWriter -from collections import defaultdict - -import parakeet -from parakeet.utils import checkpoint, mp_tools - -__all__ = ["ExperimentBase"] - - -class ExperimentBase(object): - """ - An experiment template in order to structure the training code and take - care of saving, loading, logging, visualization stuffs. It's intended to - be flexible and simple. - - So it only handles output directory (create directory for the output, - create a checkpoint directory, dump the config in use and create - visualizer and logger) in a standard way without enforcing any - input-output protocols to the model and dataloader. It leaves the main - part for the user to implement their own (setup the model, criterion, - optimizer, define a training step, define a validation function and - customize all the text and visual logs).
- It does not save too much boilerplate code. The users still have to write - the forward/backward/update mannually, but they are free to add - non-standard behaviors if needed. - We have some conventions to follow. - 1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and - ``valid_loader``, ``config`` and ``args`` attributes. - 2. The config should have a ``training`` field, which has - ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is - used as the trigger to invoke validation, checkpointing and stop of the - experiment. - 3. There are four methods, namely ``train_batch``, ``valid``, - ``setup_model`` and ``setup_dataloader`` that should be implemented. - Feel free to add/overwrite other methods and standalone functions if you - need. - - Parameters - ---------- - config: yacs.config.CfgNode - The configuration used for the experiment. - - args: argparse.Namespace - The parsed command line arguments. - Examples - -------- - >>> def main_sp(config, args): - >>> exp = Experiment(config, args) - >>> exp.setup() - >>> exp.run() - >>> - >>> config = get_cfg_defaults() - >>> parser = default_argument_parser() - >>> args = parser.parse_args() - >>> if args.config: - >>> config.merge_from_file(args.config) - >>> if args.opts: - >>> config.merge_from_list(args.opts) - >>> config.freeze() - >>> - >>> if args.nprocs > 1 and args.device == "gpu": - >>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) - >>> else: - >>> main_sp(config, args) - """ - - def __init__(self, config, args): - self.config = config - self.args = args - - def setup(self): - """Setup the experiment. - """ - paddle.set_device(self.args.device) - if self.parallel: - self.init_parallel() - - self.setup_output_dir() - self.dump_config() - self.setup_visualizer() - self.setup_logger() - self.setup_checkpointer() - - self.setup_dataloader() - self.setup_model() - - self.iteration = 0 - self.epoch = 0 - - @property - def parallel(self): - """A flag indicating whether the experiment should run with - multiprocessing. - """ - return self.args.device == "gpu" and self.args.nprocs > 1 - - def init_parallel(self): - """Init environment for multiprocess training. - """ - dist.init_parallel_env() - - def save(self): - """Save checkpoint (model parameters and optimizer states). - """ - checkpoint.save_parameters(self.checkpoint_dir, self.iteration, - self.model, self.optimizer) - - def load_or_resume(self): - """Resume from latest checkpoint at checkpoints in the output - directory or load a specified checkpoint. - - If ``args.checkpoint_path`` is not None, load the checkpoint, else - resume training. - """ - iteration = checkpoint.load_parameters( - self.model, - self.optimizer, - checkpoint_dir=self.checkpoint_dir, - checkpoint_path=self.args.checkpoint_path) - self.iteration = iteration - - def read_batch(self): - """Read a batch from the train_loader. - Returns - ------- - List[Tensor] - A batch. - """ - try: - batch = next(self.iterator) - except StopIteration: - self.new_epoch() - batch = next(self.iterator) - return batch - - def new_epoch(self): - """Reset the train loader and increment ``epoch``. - """ - self.epoch += 1 - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) - self.iterator = iter(self.train_loader) - - def train(self): - """The training process. - - It includes forward/backward/update and periodical validation and - saving. 
- """ - self.new_epoch() - while self.iteration < self.config.training.max_iteration: - self.iteration += 1 - self.train_batch() - - if self.iteration % self.config.training.valid_interval == 0: - self.valid() - - if self.iteration % self.config.training.save_interval == 0: - self.save() - - def run(self): - """The routine of the experiment after setup. This method is intended - to be used by the user. - """ - self.load_or_resume() - try: - self.train() - except KeyboardInterrupt: - self.save() - exit(-1) - - @mp_tools.rank_zero_only - def setup_output_dir(self): - """Create a directory used for output. - """ - # output dir - output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - - self.output_dir = output_dir - - @mp_tools.rank_zero_only - def setup_checkpointer(self): - """Create a directory used to save checkpoints into. - - It is "checkpoints" inside the output directory. - """ - # checkpoint dir - checkpoint_dir = self.output_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) - - self.checkpoint_dir = checkpoint_dir - - @mp_tools.rank_zero_only - def setup_visualizer(self): - """Initialize a visualizer to log the experiment. - - The visual log is saved in the output directory. - - Notes - ------ - Only the main process has a visualizer with it. Use multiple - visualizers in multiprocess to write to a same log file may cause - unexpected behaviors. - """ - # visualizer - visualizer = SummaryWriter(logdir=str(self.output_dir)) - - self.visualizer = visualizer - - def setup_logger(self): - """Initialize a text logger to log the experiment. - - Each process has its own text logger. The logging message is write to - the standard output and a text file named ``worker_n.log`` in the - output directory, where ``n`` means the rank of the process. - """ - logger = logging.getLogger(__name__) - logger.setLevel("INFO") - logger.addHandler(logging.StreamHandler()) - log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank()) - logger.addHandler(logging.FileHandler(str(log_file))) - - self.logger = logger - - @mp_tools.rank_zero_only - def dump_config(self): - """Save the configuration used for this experiment. - - It is saved in to ``config.yaml`` in the output directory at the - beginning of the experiment. - """ - with open(self.output_dir / "config.yaml", 'wt') as f: - print(self.config, file=f) - - def train_batch(self): - """The training loop. A subclass should implement this method. - """ - raise NotImplementedError("train_batch should be implemented.") - - @mp_tools.rank_zero_only - @paddle.no_grad() - def valid(self): - """The validation. A subclass should implement this method. - """ - raise NotImplementedError("valid should be implemented.") - - def setup_model(self): - """Setup model, criterion and optimizer, etc. A subclass should - implement this method. - """ - raise NotImplementedError("setup_model should be implemented.") - - def setup_dataloader(self): - """Setup training dataloader and validation dataloader. A subclass - should implement this method. - """ - raise NotImplementedError("setup_dataloader should be implemented.") \ No newline at end of file diff --git a/test.py b/test.py index d3b601e98..6fac5e4c3 100644 --- a/test.py +++ b/test.py @@ -13,119 +13,38 @@ # limitations under the License. 
"""Evaluation for DeepSpeech2 model.""" +import io +import logging import argparse import functools -import paddle.fluid as fluid -from data_utils.data import DataGenerator -from model_utils.model import DeepSpeech2Model -from model_utils.model_check import check_cuda, check_version + +from paddle import distributed as dist + +from utils.utility import print_arguments +from training.cli import default_argument_parser + +from model_utils.config import get_cfg_defaults +from model_utils.model import DeepSpeech2Trainer as Trainer from utils.error_rate import char_errors, word_errors -from utils.utility import add_arguments, print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('batch_size', int, 128, "Minibatch size.") -add_arg('beam_size', int, 500, "Beam search width.") -add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.") -add_arg('num_conv_layers', int, 2, "# of convolution layers.") -add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") -add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('alpha', float, 2.5, "Coef of LM for beam search.") -add_arg('beta', float, 0.3, "Coef of WC for beam search.") -add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.") -add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.") -add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") -add_arg('use_gpu', bool, True, "Use GPU or not.") -add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " - "bi-directional RNNs. Not for GRU.") -add_arg('test_manifest', str, - 'data/librispeech/manifest.test-clean', - "Filepath of manifest to evaluate.") -add_arg('mean_std_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of normalizer's mean & std.") -add_arg('vocab_path', str, - 'data/librispeech/vocab.txt', - "Filepath of vocabulary.") -add_arg('model_path', str, - './checkpoints/libri/step_final', - "If None, the training starts from scratch, " - "otherwise, it resumes from the pre-trained model.") -add_arg('lang_model_path', str, - 'models/lm/common_crawl_00.prune01111.trie.klm', - "Filepath for language model.") -add_arg('decoding_method', str, - 'ctc_beam_search', - "Decoding method. Options: ctc_beam_search, ctc_greedy", - choices = ['ctc_beam_search', 'ctc_greedy']) -add_arg('error_rate_type', str, - 'wer', - "Error rate type for evaluation.", - choices=['wer', 'cer']) -add_arg('specgram_type', str, - 'linear', - "Audio feature type. 
Options: linear, mfcc.", - choices=['linear', 'mfcc']) -# yapf: disable -args = parser.parse_args() def evaluate(): """Evaluate on whole test data for DeepSpeech2.""" - # check if set use_gpu=True in paddlepaddle cpu version - check_cuda(args.use_gpu) - # check if paddlepaddle version is satisfied - check_version() - - if args.use_gpu: - place = fluid.CUDAPlace(0) - else: - place = fluid.CPUPlace() - - data_generator = DataGenerator( - vocab_filepath=args.vocab_path, - mean_std_filepath=args.mean_std_path, - augmentation_config='{}', - specgram_type=args.specgram_type, - keep_transcription_text=True, - place = place, - is_training = False) - batch_reader = data_generator.batch_reader_creator( - manifest_path=args.test_manifest, - batch_size=args.batch_size, - sortagrad=False, - shuffle_method=None) - - ds2_model = DeepSpeech2Model( - vocab_size=data_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_layer_size=args.rnn_layer_size, - use_gru=args.use_gru, - share_rnn_weights=args.share_rnn_weights, - place=place, - init_from_pretrained_model=args.model_path) - # decoders only accept string encoded in utf-8 vocab_list = [chars for chars in data_generator.vocab_list] - if args.decoding_method == "ctc_beam_search": - ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path, - vocab_list) errors_func = char_errors if args.error_rate_type == 'cer' else word_errors errors_sum, len_refs, num_ins = 0.0, 0, 0 ds2_model.logger.info("start evaluation ...") + for infer_data in batch_reader(): probs_split = ds2_model.infer_batch_probs( - infer_data=infer_data, - feeding_dict=data_generator.feeding) + infer_data=infer_data, feeding_dict=data_generator.feeding) if args.decoding_method == "ctc_greedy": result_transcripts = ds2_model.decode_batch_greedy( - probs_split=probs_split, - vocab_list=vocab_list) + probs_split=probs_split, vocab_list=vocab_list) else: result_transcripts = ds2_model.decode_batch_beam_search( probs_split=probs_split, @@ -136,6 +55,7 @@ def evaluate(): cutoff_top_n=args.cutoff_top_n, vocab_list=vocab_list, num_processes=args.num_proc_bsearch) + target_transcripts = infer_data[1] for target, result in zip(target_transcripts, result_transcripts): @@ -145,15 +65,38 @@ def evaluate(): num_ins += 1 print("Error rate [%s] (%d/?) = %f" % (args.error_rate_type, num_ins, errors_sum / len_refs)) + print("Final error rate [%s] (%d/%d) = %f" % (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs)) ds2_model.logger.info("finish evaluation") -def main(): - print_arguments(args) - evaluate() +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + + +def main(config, args): + main_sp(config, args) + + +if __name__ == "__main__": + parser = default_argument_parser() + args = parser.parse_args() + print_arguments(args) -if __name__ == '__main__': - main() + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/training/trainer.py b/training/trainer.py index e1b898df5..9ffc42813 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -121,6 +121,7 @@ class Trainer(): """ dist.init_parallel_env() + @mp_tools.rank_zero_only def save(self): """Save checkpoint (model parameters and optimizer states). 
""" @@ -190,8 +191,9 @@ class Trainer(): except KeyboardInterrupt: self.save() exit(-1) + finally: + self.destory() - @mp_tools.rank_zero_only def setup_output_dir(self): """Create a directory used for output. """ @@ -201,7 +203,6 @@ class Trainer(): self.output_dir = output_dir - @mp_tools.rank_zero_only def setup_checkpointer(self): """Create a directory used to save checkpoints into. @@ -213,6 +214,11 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir + @mp_tools.rank_zero_only + def destory(self): + # https://github.com/pytorch/fairseq/issues/2357 + self.visualizer.close() + @mp_tools.rank_zero_only def setup_visualizer(self): """Initialize a visualizer to log the experiment. diff --git a/utils/checkpoint.py b/utils/checkpoint.py index 5a09f20a1..e0c6938c9 100644 --- a/utils/checkpoint.py +++ b/utils/checkpoint.py @@ -14,6 +14,7 @@ import os import time +import logging import numpy as np import paddle from paddle import distributed as dist @@ -22,6 +23,9 @@ from paddle.optimizer import Optimizer from utils import mp_tools +logger = logging.getLogger(__name__) +logger.setLevel("INFO") + __all__ = ["load_parameters", "save_parameters"] @@ -94,15 +98,15 @@ def load_parameters(model, params_path = checkpoint_path + ".pdparams" model_dict = paddle.load(params_path) model.set_state_dict(model_dict) - print( + logger.info( "[checkpoint] Rank {}: loaded model from {}".format(rank, params_path)) optimizer_path = checkpoint_path + ".pdopt" if optimizer and os.path.isfile(optimizer_path): optimizer_dict = paddle.load(optimizer_path) optimizer.set_state_dict(optimizer_dict) - print("[checkpoint] Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) + logger.info("[checkpoint] Rank {}: loaded optimizer state from {}". + format(rank, optimizer_path)) return iteration @@ -124,12 +128,13 @@ def save_parameters(checkpoint_dir, iteration, model, optimizer=None): model_dict = model.state_dict() params_path = checkpoint_path + ".pdparams" paddle.save(model_dict, params_path) - print("[checkpoint] Saved model to {}".format(params_path)) + logger.info("[checkpoint] Saved model to {}".format(params_path)) if optimizer: opt_dict = optimizer.state_dict() optimizer_path = checkpoint_path + ".pdopt" paddle.save(opt_dict, optimizer_path) - print("[checkpoint] Saved optimzier state to {}".format(optimizer_path)) + logger.info( + "[checkpoint] Saved optimzier state to {}".format(optimizer_path)) _save_checkpoint(checkpoint_dir, iteration) diff --git a/utils/mp_tools.py b/utils/mp_tools.py index 0daa62af2..9c3c3d548 100644 --- a/utils/mp_tools.py +++ b/utils/mp_tools.py @@ -20,13 +20,12 @@ __all__ = ["rank_zero_only"] def rank_zero_only(func): - rank = dist.get_rank() - @wraps(func) def wrapper(*args, **kwargs): + rank = dist.get_rank() if rank != 0: return result = func(*args, **kwargs) return result - return wrapper \ No newline at end of file + return wrapper