fix rank_zero_only decreator error close tensorboard when train over add decoding config and codepull/522/head
parent
d79ae3824a
commit
c18162ca90
@ -1,373 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Contains data generator for orgnaizing various audio data preprocessing
|
|
||||||
pipeline and offering data reader interface of PaddlePaddle requirements.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import random
|
|
||||||
import tarfile
|
|
||||||
import multiprocessing
|
|
||||||
import numpy as np
|
|
||||||
import paddle.fluid as fluid
|
|
||||||
from threading import local
|
|
||||||
from data_utils.utility import read_manifest
|
|
||||||
from data_utils.augmentor.augmentation import AugmentationPipeline
|
|
||||||
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
|
|
||||||
from data_utils.speech import SpeechSegment
|
|
||||||
from data_utils.normalizer import FeatureNormalizer
|
|
||||||
|
|
||||||
__all__ = ['DataGenerator']
|
|
||||||
|
|
||||||
|
|
||||||
class DataGenerator():
    """
    DataGenerator provides basic audio data preprocessing pipeline, and offers
    data reader interfaces of PaddlePaddle requirements.

    :param vocab_filepath: Vocabulary filepath for indexing tokenized
                           transcripts.
    :type vocab_filepath: str
    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|str
    :param augmentation_config: Augmentation configuration in json string.
                                Details see AugmentationPipeline.__doc__.
    :type augmentation_config: str
    :param max_duration: Audio with duration (in seconds) greater than
                         this will be discarded.
    :type max_duration: float
    :param min_duration: Audio with duration (in seconds) smaller than
                         this will be discarded.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Used when specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                 before extracting the features.
    :type use_dB_normalization: bool
    :param random_seed: Random seed.
    :type random_seed: int
    :param keep_transcription_text: If set to True, transcription text will
                                    be passed forward directly without
                                    converting to index sequence.
    :type keep_transcription_text: bool
    :param place: The place to run the program.
    :type place: CPUPlace or CUDAPlace
    :param is_training: If set to True, generate text data for training,
                        otherwise, generate text data for infer.
    :type is_training: bool
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 random_seed=0,
                 keep_transcription_text=False,
                 place=fluid.CPUPlace(),
                 is_training=True):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._rng = random.Random(random_seed)
        self._keep_transcription_text = keep_transcription_text
        self._epoch = 0
        self._is_training = is_training
        # Cache for tar files info. Thread-local so each reader thread keeps
        # its own open tarfile handles (tarfile objects are not thread-safe).
        self._local_data = local()
        self._local_data.tar2info = {}
        self._local_data.tar2object = {}
        self._place = place

    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        # 'tar:/path/to/x.tar#member.wav' addresses a member inside a tarball.
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        specgram = self._normalizer.apply(specgram)
        return specgram, transcript_part

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             shuffle_method="batch_shuffle"):
        """
        Batch data reader creator for audio data. Return a callable generator
        function to produce batches of data.

        Audio features within one batch will be padded with zeros to have the
        same shape, or a user-defined shape.

        :param manifest_path: Filepath of manifest for audio files.
        :type manifest_path: str
        :param batch_size: Number of instances in a batch.
        :type batch_size: int
        :param padding_to: If set -1, the maximun shape in the batch
                           will be used as the target shape for padding.
                           Otherwise, `padding_to` will be the target shape.
        :type padding_to: int
        :param flatten: If set True, audio features will be flatten to 1darray.
        :type flatten: bool
        :param sortagrad: If set True, sort the instances by audio duration
                          in the first epoch for speed up training.
        :type sortagrad: bool
        :param shuffle_method: Shuffle method. Options:
                                '' or None: no shuffle.
                                'instance_shuffle': instance-wise shuffle.
                                'batch_shuffle': similarly-sized instances are
                                                 put into batches, and then
                                                 batch-wise shuffle the batches.
                                                 For more details, please see
                                                 ``_batch_shuffle.__doc__``.
                                'batch_shuffle_clipped': 'batch_shuffle' with
                                                         head shift and tail
                                                         clipping. For more
                                                         details, please see
                                                         ``_batch_shuffle``.
                              If sortagrad is True, shuffle is disabled
                              for the first epoch.
        :type shuffle_method: None|str
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)

            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            else:
                if shuffle_method == "batch_shuffle":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=False)
                elif shuffle_method == "batch_shuffle_clipped":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=True)
                elif shuffle_method == "instance_shuffle":
                    self._rng.shuffle(manifest)
                # FIX: identity test for None (PEP 8), not equality.
                elif shuffle_method is None:
                    pass
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     shuffle_method)
            # prepare batches
            batch = []
            instance_reader = self._instance_reader_creator(manifest)

            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
                    batch = []
            # flush the trailing, possibly incomplete batch
            if len(batch) >= 1:
                yield self._padding_batch(batch, padding_to, flatten)
            self._epoch += 1

        return batch_reader

    @property
    def feeding(self):
        """Returns data reader's feeding dict.

        :return: Data feeding dict.
        :rtype: dict
        """
        feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
        return feeding_dict

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map containing tarinfoes
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from tar file
        and cached tar file info for next reading request.
        """
        # file looks like 'tar:/path/to/x.tar#member/path.wav'
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        # The thread-local may be fresh for a new thread; re-create caches.
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            # FIX: renamed local 'object' -> 'tar_object'; the original name
            # shadowed the 'object' builtin.
            tar_object, infoes = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infoes
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        Instance: a tuple of ndarray of audio spectrogram and a list of
        token indices for transcript.
        """

        def reader():
            for instance in manifest:
                inst = self.process_utterance(instance["audio_filepath"],
                                              instance["text"])
                yield inst

        return reader

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one bach.

        If ``padding_to`` is -1, the maximun shape in the batch will be used
        as the target shape for padding. Otherwise, `padding_to` will be the
        target shape (only refers to the second axis).

        If `flatten` is True, features will be flatten to 1darray.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        max_text_length = max([len(text) for audio, text in batch])
        # padding
        padded_audios = []
        audio_lens = []
        texts, text_lens = [], []
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            padded_audios.append(padded_audio)
            audio_lens.append(audio.shape[1])
            if self._is_training:
                # token ids: right-pad with zeros to the longest transcript
                padded_text = np.zeros([max_text_length])
                padded_text[:len(text)] = text
                texts.append(padded_text)
            else:
                # raw text is kept as-is for inference
                texts.append(text)
            text_lens.append(len(text))

        padded_audios = np.array(padded_audios).astype('float32')
        audio_lens = np.array(audio_lens).astype('int64')
        if self._is_training:
            texts = np.array(texts).astype('int32')
        text_lens = np.array(text_lens).astype('int64')
        return padded_audios, texts, audio_lens, text_lens

    def _batch_shuffle(self, manifest, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better efficiency
        and make a batch-wise shuffle.

        1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        :param manifest: Manifest contents. List of dict.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generate
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch shuffled mainifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self._rng.shuffle(batch_manifest)
        batch_manifest = [item for batch in batch_manifest for item in batch]
        if not clipped:
            res_len = len(manifest) - shift_len - len(batch_manifest)
            # FIX: only re-append the tail when it is non-empty. When
            # res_len == 0, manifest[-0:] slices the WHOLE manifest, which
            # duplicated every instance in the epoch.
            if res_len > 0:
                batch_manifest.extend(manifest[-res_len:])
            batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest
|
|
@ -1,279 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
import numpy as np
|
|
||||||
import paddle
|
|
||||||
from paddle import distributed as dist
|
|
||||||
from paddle.io import DataLoader, DistributedBatchSampler
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import parakeet
|
|
||||||
from parakeet.utils import checkpoint, mp_tools
|
|
||||||
|
|
||||||
__all__ = ["ExperimentBase"]
|
|
||||||
|
|
||||||
|
|
||||||
class ExperimentBase(object):
    """
    An experiment template in order to structure the training code and take
    care of saving, loading, logging, visualization stuffs. It's intended to
    be flexible and simple.

    So it only handles output directory (create directory for the output,
    create a checkpoint directory, dump the config in use and create
    visualizer and logger) in a standard way without enforcing any
    input-output protocols to the model and dataloader. It leaves the main
    part for the user to implement their own (setup the model, criterion,
    optimizer, define a training step, define a validation function and
    customize all the text and visual logs).
    It does not save too much boilerplate code. The users still have to write
    the forward/backward/update mannually, but they are free to add
    non-standard behaviors if needed.
    We have some conventions to follow.
    1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
    ``valid_loader``, ``config`` and ``args`` attributes.
    2. The config should have a ``training`` field, which has
    ``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
    used as the trigger to invoke validation, checkpointing and stop of the
    experiment.
    3. There are four methods, namely ``train_batch``, ``valid``,
    ``setup_model`` and ``setup_dataloader`` that should be implemented.
    Feel free to add/overwrite other methods and standalone functions if you
    need.

    Parameters
    ----------
    config: yacs.config.CfgNode
        The configuration used for the experiment.

    args: argparse.Namespace
        The parsed command line arguments.
    Examples
    --------
    >>> def main_sp(config, args):
    >>>     exp = Experiment(config, args)
    >>>     exp.setup()
    >>>     exp.run()
    >>>
    >>> config = get_cfg_defaults()
    >>> parser = default_argument_parser()
    >>> args = parser.parse_args()
    >>> if args.config:
    >>>     config.merge_from_file(args.config)
    >>> if args.opts:
    >>>     config.merge_from_list(args.opts)
    >>> config.freeze()
    >>>
    >>> if args.nprocs > 1 and args.device == "gpu":
    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
    >>> else:
    >>>     main_sp(config, args)
    """

    def __init__(self, config, args):
        self.config = config
        self.args = args

    def setup(self):
        """Setup the experiment.
        """
        paddle.set_device(self.args.device)
        if self.parallel:
            self.init_parallel()

        self.setup_output_dir()
        self.dump_config()
        self.setup_visualizer()
        self.setup_logger()
        self.setup_checkpointer()

        self.setup_dataloader()
        self.setup_model()

        self.iteration = 0
        self.epoch = 0

    @property
    def parallel(self):
        """A flag indicating whether the experiment should run with
        multiprocessing.
        """
        return self.args.device == "gpu" and self.args.nprocs > 1

    def init_parallel(self):
        """Init environment for multiprocess training.
        """
        dist.init_parallel_env()

    def save(self):
        """Save checkpoint (model parameters and optimizer states).
        """
        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
                                   self.model, self.optimizer)

    def load_or_resume(self):
        """Resume from latest checkpoint at checkpoints in the output
        directory or load a specified checkpoint.

        If ``args.checkpoint_path`` is not None, load the checkpoint, else
        resume training.
        """
        iteration = checkpoint.load_parameters(
            self.model,
            self.optimizer,
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_path=self.args.checkpoint_path)
        self.iteration = iteration

    def read_batch(self):
        """Read a batch from the train_loader.
        Returns
        -------
        List[Tensor]
            A batch.
        """
        try:
            batch = next(self.iterator)
        except StopIteration:
            # the loader is exhausted; start a new epoch and retry once
            self.new_epoch()
            batch = next(self.iterator)
        return batch

    def new_epoch(self):
        """Reset the train loader and increment ``epoch``.
        """
        self.epoch += 1
        if self.parallel:
            # re-seed the distributed sampler so shuffling differs per epoch
            self.train_loader.batch_sampler.set_epoch(self.epoch)
        self.iterator = iter(self.train_loader)

    def train(self):
        """The training process.

        It includes forward/backward/update and periodical validation and
        saving.
        """
        self.new_epoch()
        while self.iteration < self.config.training.max_iteration:
            self.iteration += 1
            self.train_batch()

            if self.iteration % self.config.training.valid_interval == 0:
                self.valid()

            if self.iteration % self.config.training.save_interval == 0:
                self.save()

    def run(self):
        """The routine of the experiment after setup. This method is intended
        to be used by the user.
        """
        self.load_or_resume()
        try:
            self.train()
        except KeyboardInterrupt:
            self.save()
            exit(-1)
        finally:
            # FIX: close the tensorboard writer when training is over (or
            # aborted) so buffered events are flushed to disk. Only the rank
            # zero process owns a visualizer (setup_visualizer is
            # rank_zero_only), hence the hasattr guard.
            if hasattr(self, "visualizer"):
                self.visualizer.close()

    @mp_tools.rank_zero_only
    def setup_output_dir(self):
        """Create a directory used for output.
        """
        # output dir
        output_dir = Path(self.args.output).expanduser()
        output_dir.mkdir(parents=True, exist_ok=True)

        self.output_dir = output_dir

    @mp_tools.rank_zero_only
    def setup_checkpointer(self):
        """Create a directory used to save checkpoints into.

        It is "checkpoints" inside the output directory.
        """
        # checkpoint dir
        checkpoint_dir = self.output_dir / "checkpoints"
        checkpoint_dir.mkdir(exist_ok=True)

        self.checkpoint_dir = checkpoint_dir

    @mp_tools.rank_zero_only
    def setup_visualizer(self):
        """Initialize a visualizer to log the experiment.

        The visual log is saved in the output directory.

        Notes
        ------
        Only the main process has a visualizer with it. Use multiple
        visualizers in multiprocess to write to a same log file may cause
        unexpected behaviors.
        """
        # visualizer
        visualizer = SummaryWriter(logdir=str(self.output_dir))

        self.visualizer = visualizer

    def setup_logger(self):
        """Initialize a text logger to log the experiment.

        Each process has its own text logger. The logging message is write to
        the standard output and a text file named ``worker_n.log`` in the
        output directory, where ``n`` means the rank of the process.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel("INFO")
        logger.addHandler(logging.StreamHandler())
        log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
        logger.addHandler(logging.FileHandler(str(log_file)))

        self.logger = logger

    @mp_tools.rank_zero_only
    def dump_config(self):
        """Save the configuration used for this experiment.

        It is saved in to ``config.yaml`` in the output directory at the
        beginning of the experiment.
        """
        with open(self.output_dir / "config.yaml", 'wt') as f:
            print(self.config, file=f)

    def train_batch(self):
        """The training loop. A subclass should implement this method.
        """
        raise NotImplementedError("train_batch should be implemented.")

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        """The validation. A subclass should implement this method.
        """
        raise NotImplementedError("valid should be implemented.")

    def setup_model(self):
        """Setup model, criterion and optimizer, etc. A subclass should
        implement this method.
        """
        raise NotImplementedError("setup_model should be implemented.")

    def setup_dataloader(self):
        """Setup training dataloader and validation dataloader. A subclass
        should implement this method.
        """
        raise NotImplementedError("setup_dataloader should be implemented.")
|
Loading…
Reference in new issue