fix dataset batch shuffle shift start from 1

fix rank_zero_only decorator error
close tensorboard when train over
add decoding config and code
pull/522/head
Hui Zhang 5 years ago
parent d79ae3824a
commit c18162ca90

@@ -1,373 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data generator for orgnaizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle requirements.
"""
import random
import tarfile
import multiprocessing
import numpy as np
import paddle.fluid as fluid
from threading import local
from data_utils.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer
__all__ = ['DataGenerator']
class DataGenerator():
"""
DataGenerator provides a basic audio data preprocessing pipeline and offers
data reader interfaces that meet PaddlePaddle's requirements.
:param vocab_filepath: Vocabulary filepath for indexing tokenized
transcripts.
:type vocab_filepath: str
:param mean_std_filepath: File containing the pre-computed mean and stddev.
:type mean_std_filepath: None|str
:param augmentation_config: Augmentation configuration in json string.
Details see AugmentationPipeline.__doc__.
:type augmentation_config: str
:param max_duration: Audio with duration (in seconds) greater than
this will be discarded.
:type max_duration: float
:param min_duration: Audio with duration (in seconds) smaller than
this will be discarded.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: Used when specgram_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned.
:type max_freq: None|float
:param specgram_type: Specgram feature type. Options: 'linear'.
:type specgram_type: str
:param use_dB_normalization: Whether to normalize the audio to -20 dB
before extracting the features.
:type use_dB_normalization: bool
:param random_seed: Random seed.
:type random_seed: int
:param keep_transcription_text: If set to True, transcription text will
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
:param place: The place to run the program.
:type place: CPUPlace or CUDAPlace
:param is_training: If set to True, generate text data for training,
otherwise, generate text data for inference.
:type is_training: bool
"""
def __init__(self,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
place=fluid.CPUPlace(),
is_training=True):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(
vocab_filepath=vocab_filepath,
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
use_dB_normalization=use_dB_normalization)
self._rng = random.Random(random_seed)
self._keep_transcription_text = keep_transcription_text
self._epoch = 0
self._is_training = is_training
# for caching tar files info
self._local_data = local()
self._local_data.tar2info = {}
self._local_data.tar2object = {}
self._place = place
def process_utterance(self, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of transcription part,
where transcription part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
specgram = self._normalizer.apply(specgram)
return specgram, transcript_part
def batch_reader_creator(self,
manifest_path,
batch_size,
padding_to=-1,
flatten=False,
sortagrad=False,
shuffle_method="batch_shuffle"):
"""
Batch data reader creator for audio data. Return a callable generator
function to produce batches of data.
Audio features within one batch will be padded with zeros to have the
same shape, or a user-defined shape.
:param manifest_path: Filepath of manifest for audio files.
:type manifest_path: str
:param batch_size: Number of instances in a batch.
:type batch_size: int
:param padding_to: If set to -1, the maximum shape in the batch
will be used as the target shape for padding.
Otherwise, `padding_to` will be the target shape.
:type padding_to: int
:param flatten: If set True, audio features will be flattened to a 1darray.
:type flatten: bool
:param sortagrad: If set True, sort the instances by audio duration
in the first epoch to speed up training.
:type sortagrad: bool
:param shuffle_method: Shuffle method. Options:
'' or None: no shuffle.
'instance_shuffle': instance-wise shuffle.
'batch_shuffle': similarly-sized instances are
put into batches, and then
batch-wise shuffle the batches.
For more details, please see
``_batch_shuffle.__doc__``.
'batch_shuffle_clipped': 'batch_shuffle' with
head shift and tail
clipping. For more
details, please see
``_batch_shuffle``.
If sortagrad is True, shuffle is disabled
for the first epoch.
:type shuffle_method: None|str
:return: Batch reader function, producing batches of data when called.
:rtype: callable
"""
def batch_reader():
# read manifest
manifest = read_manifest(
manifest_path=manifest_path,
max_duration=self._max_duration,
min_duration=self._min_duration)
# sort (by duration) or batch-wise shuffle the manifest
if self._epoch == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
else:
if shuffle_method == "batch_shuffle":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=False)
elif shuffle_method == "batch_shuffle_clipped":
manifest = self._batch_shuffle(
manifest, batch_size, clipped=True)
elif shuffle_method == "instance_shuffle":
self._rng.shuffle(manifest)
elif shuffle_method is None:
pass
else:
raise ValueError("Unknown shuffle method %s." %
shuffle_method)
# prepare batches
batch = []
instance_reader = self._instance_reader_creator(manifest)
for instance in instance_reader():
batch.append(instance)
if len(batch) == batch_size:
yield self._padding_batch(batch, padding_to, flatten)
batch = []
if len(batch) >= 1:
yield self._padding_batch(batch, padding_to, flatten)
self._epoch += 1
return batch_reader
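A hedged usage sketch of the reader creator above; the vocab, normalizer and manifest paths are hypothetical:

    generator = DataGenerator(
        vocab_filepath='data/vocab.txt',        # hypothetical vocab file
        mean_std_filepath='data/mean_std.npz')  # hypothetical normalizer stats
    batch_reader = generator.batch_reader_creator(
        manifest_path='data/manifest.train',    # hypothetical manifest
        batch_size=32,
        sortagrad=True,                         # sort by duration in epoch 0
        shuffle_method='batch_shuffle_clipped')
    for padded_audios, texts, audio_lens, text_lens in batch_reader():
        pass  # feed one zero-padded batch to the trainer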
@property
def feeding(self):
"""Returns data reader's feeding dict.
:return: Data feeding dict.
:rtype: dict
"""
feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
return feeding_dict
@property
def vocab_size(self):
"""Return the vocabulary size.
:return: Vocabulary size.
:rtype: int
"""
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
"""Return the vocabulary in list.
:return: Vocabulary in list.
:rtype: list
"""
return self._speech_featurizer.vocab_list
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
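The `tar:` prefix above lets a manifest entry point at audio packed inside a tar archive. A minimal sketch of the path convention, with a hypothetical archive and member name:

    # format: "tar:<path to tar file>#<member name inside the archive>"
    file = 'tar:/data/libri/dev.tar#dev-clean/utt_0001.wav'
    # same parsing as _subfile_from_tar: drop the scheme, then split
    # the archive path from the member name at the first '#'
    tarpath, filename = file.split(':', 1)[1].split('#', 1)
    assert tarpath == '/data/libri/dev.tar'
    assert filename == 'dev-clean/utt_0001.wav'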
def _instance_reader_creator(self, manifest):
"""
Instance reader creator. Create a callable function to produce
instances of data.
Instance: a tuple of ndarray of audio spectrogram and a list of
token indices for transcript.
"""
def reader():
for instance in manifest:
inst = self.process_utterance(instance["audio_filepath"],
instance["text"])
yield inst
return reader
def _padding_batch(self, batch, padding_to=-1, flatten=False):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
If ``padding_to`` is -1, the maximum shape in the batch will be used
as the target shape for padding. Otherwise, ``padding_to`` will be the
target shape (only refers to the second axis).
If ``flatten`` is True, features will be flattened to a 1darray.
"""
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = padding_to
max_text_length = max([len(text) for audio, text in batch])
# padding
padded_audios = []
audio_lens = []
texts, text_lens = [], []
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1])
if self._is_training:
padded_text = np.zeros([max_text_length])
padded_text[:len(text)] = text
texts.append(padded_text)
else:
texts.append(text)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
if self._is_training:
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
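For concreteness, a toy run of the zero-padding scheme above (shapes are illustrative): two spectrograms with 161 frequency bins but different lengths are left-aligned onto a common time axis.

    import numpy as np

    # two toy (freq_bins, time) spectrograms plus token-id transcripts
    batch = [(np.ones([161, 50]), [3, 7, 2]), (np.ones([161, 80]), [5, 1])]

    max_length = max(audio.shape[1] for audio, _ in batch)  # 80
    padded = []
    for audio, _ in batch:
        pad = np.zeros([audio.shape[0], max_length])
        pad[:, :audio.shape[1]] = audio   # copy, leaving zeros at the tail
        padded.append(pad)
    padded_audios = np.stack(padded).astype('float32')
    print(padded_audios.shape)            # (2, 161, 80)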
def _batch_shuffle(self, manifest, batch_size, clipped=False):
"""Put similarly-sized instances into minibatches for better efficiency
and make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly shift `k` instances in order to create different batches
for different epochs. Create minibatches.
4. Shuffle the minibatches.
:param manifest: Manifest contents. List of dict.
:type manifest: list
:param batch_size: Batch size. This size is also used to generate
a random number for batch shuffle.
:type batch_size: int
:param clipped: Whether to clip the heading (small shift) and trailing
(incomplete batch) instances.
:type clipped: bool
:return: Batch-shuffled manifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self._rng.randint(0, batch_size - 1)
batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
self._rng.shuffle(batch_manifest)
batch_manifest = [item for batch in batch_manifest for item in batch]
if not clipped:
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
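For illustration, a self-contained sketch of the shift-and-shuffle scheme above on toy data; `zip(*[iter(seq)] * n)` is the grouper idiom that chops the shifted list into fixed-size batches and drops the incomplete tail. The final slice uses an explicit index to sidestep the `-0` gotcha noted in the sampler fix below.

    import random

    def toy_batch_shuffle(manifest, batch_size, clipped=False, seed=0):
        rng = random.Random(seed)
        manifest = sorted(manifest)                 # stand-in for sorting by duration
        shift_len = rng.randint(0, batch_size - 1)  # random head shift
        # grouper idiom: the same iterator zipped with itself batch_size
        # times yields fixed-size batches, dropping the incomplete tail
        batches = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        rng.shuffle(batches)                        # batch-wise shuffle
        result = [item for batch in batches for item in batch]
        if not clipped:
            res_len = len(manifest) - shift_len - len(result)
            result.extend(manifest[len(manifest) - res_len:])  # leftover tail
            result.extend(manifest[:shift_len])                # shifted head
        return result

    print(toy_batch_shuffle(list(range(10)), batch_size=4))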

@@ -226,13 +226,15 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
:rtype: list
"""
rng = np.random.RandomState(self.epoch)
shift_len = rng.randint(0, batch_size - 1)
# must shift by at least one
shift_len = rng.randint(1, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert (clipped == False)
if not clipped:
res_len = len(indices) - shift_len - len(batch_indices)
# when res_len is 0, indices[-res_len:] returns the whole list: len(List[-0:]) == len(List[:])
batch_indices.extend(indices[-res_len:])
batch_indices.extend(indices[0:shift_len])
return batch_indices
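The `res_len` comment above flags a Python slicing gotcha worth spelling out: `-0` is just `0`, so `indices[-res_len:]` returns the whole list when the remainder is empty, duplicating every index. A two-line demonstration and a safe alternative:

    indices = [0, 1, 2, 3]
    res_len = 0
    # the gotcha: -0 == 0, so this is indices[0:], i.e. the whole list
    assert indices[-res_len:] == [0, 1, 2, 3]
    # a safe equivalent that yields [] when res_len == 0
    assert indices[len(indices) - res_len:] == []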
@@ -256,7 +258,9 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
else:
raise ValueError("Unknown shuffle method %s." %
self._shuffle_method)
assert len(indices) == self.total_size
assert len(
indices
) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
self.epoch += 1
# subsample
@@ -362,13 +366,15 @@ class DeepSpeech2BatchSampler(BatchSampler):
:rtype: list
"""
rng = np.random.RandomState(self.epoch)
shift_len = rng.randint(0, batch_size - 1)
# must shift by at least one
shift_len = rng.randint(1, batch_size - 1)
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert (clipped == False)
if not clipped:
res_len = len(indices) - shift_len - len(batch_indices)
# when res_len is 0, indices[-res_len:] returns the whole list: len(List[-0:]) == len(List[:])
batch_indices.extend(indices[-res_len:])
batch_indices.extend(indices[0:shift_len])
return batch_indices
@@ -392,7 +398,9 @@ class DeepSpeech2BatchSampler(BatchSampler):
else:
raise ValueError("Unknown shuffle method %s." %
self._shuffle_method)
assert len(indices) == self.total_size
assert len(
indices
) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
self.epoch += 1
# subsample
@@ -516,105 +524,105 @@ class SpeechCollator():
return padded_audios, texts, audio_lens, text_lens
def create_dataloader(manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
specgram_type='linear',
use_dB_normalization=True,
random_seed=0,
keep_transcription_text=False,
is_training=False,
batch_size=1,
num_workers=0,
sortagrad=False,
shuffle_method=None,
dist=False):
dataset = DeepSpeech2Dataset(
manifest_path,
vocab_filepath,
mean_std_filepath,
augmentation_config=augmentation_config,
max_duration=max_duration,
min_duration=min_duration,
stride_ms=stride_ms,
window_ms=window_ms,
max_freq=max_freq,
specgram_type=specgram_type,
use_dB_normalization=use_dB_normalization,
random_seed=random_seed,
keep_transcription_text=keep_transcription_text)
if dist:
batch_sampler = DeepSpeech2DistributedBatchSampler(
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=is_training,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
else:
batch_sampler = DeepSpeech2BatchSampler(
dataset,
shuffle=is_training,
batch_size=batch_size,
drop_last=is_training,
sortagrad=is_training,
shuffle_method=shuffle_method)
def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
"""
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
If ``padding_to`` is -1, the maximum shape in the batch will be used
as the target shape for padding. Otherwise, ``padding_to`` will be the
target shape (only refers to the second axis).
If ``flatten`` is True, features will be flattened to a 1darray.
"""
new_batch = []
# get target shape
max_length = max([audio.shape[1] for audio, text in batch])
if padding_to != -1:
if padding_to < max_length:
raise ValueError("If padding_to is not -1, it should be larger "
"than any instance's shape in the batch")
max_length = padding_to
max_text_length = max([len(text) for audio, text in batch])
# padding
padded_audios = []
audio_lens = []
texts, text_lens = [], []
for audio, text in batch:
padded_audio = np.zeros([audio.shape[0], max_length])
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
padded_audios.append(padded_audio)
audio_lens.append(audio.shape[1])
padded_text = np.zeros([max_text_length])
padded_text[:len(text)] = text
texts.append(padded_text)
text_lens.append(len(text))
padded_audios = np.array(padded_audios).astype('float32')
audio_lens = np.array(audio_lens).astype('int64')
texts = np.array(texts).astype('int32')
text_lens = np.array(text_lens).astype('int64')
return padded_audios, texts, audio_lens, text_lens
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=partial(padding_batch, is_training=is_training),
num_workers=num_workers, )
return loader
# def create_dataloader(manifest_path,
# vocab_filepath,
# mean_std_filepath,
# augmentation_config='{}',
# max_duration=float('inf'),
# min_duration=0.0,
# stride_ms=10.0,
# window_ms=20.0,
# max_freq=None,
# specgram_type='linear',
# use_dB_normalization=True,
# random_seed=0,
# keep_transcription_text=False,
# is_training=False,
# batch_size=1,
# num_workers=0,
# sortagrad=False,
# shuffle_method=None,
# dist=False):
# dataset = DeepSpeech2Dataset(
# manifest_path,
# vocab_filepath,
# mean_std_filepath,
# augmentation_config=augmentation_config,
# max_duration=max_duration,
# min_duration=min_duration,
# stride_ms=stride_ms,
# window_ms=window_ms,
# max_freq=max_freq,
# specgram_type=specgram_type,
# use_dB_normalization=use_dB_normalization,
# random_seed=random_seed,
# keep_transcription_text=keep_transcription_text)
# if dist:
# batch_sampler = DeepSpeech2DistributedBatchSampler(
# dataset,
# batch_size,
# num_replicas=None,
# rank=None,
# shuffle=is_training,
# drop_last=is_training,
# sortagrad=is_training,
# shuffle_method=shuffle_method)
# else:
# batch_sampler = DeepSpeech2BatchSampler(
# dataset,
# shuffle=is_training,
# batch_size=batch_size,
# drop_last=is_training,
# sortagrad=is_training,
# shuffle_method=shuffle_method)
# def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
# """
# Padding audio features with zeros to make them have the same shape (or
# a user-defined shape) within one batch.
# If ``padding_to`` is -1, the maximum shape in the batch will be used
# as the target shape for padding. Otherwise, ``padding_to`` will be the
# target shape (only refers to the second axis).
# If ``flatten`` is True, features will be flattened to a 1darray.
# """
# new_batch = []
# # get target shape
# max_length = max([audio.shape[1] for audio, text in batch])
# if padding_to != -1:
# if padding_to < max_length:
# raise ValueError("If padding_to is not -1, it should be larger "
# "than any instance's shape in the batch")
# max_length = padding_to
# max_text_length = max([len(text) for audio, text in batch])
# # padding
# padded_audios = []
# audio_lens = []
# texts, text_lens = [], []
# for audio, text in batch:
# padded_audio = np.zeros([audio.shape[0], max_length])
# padded_audio[:, :audio.shape[1]] = audio
# if flatten:
# padded_audio = padded_audio.flatten()
# padded_audios.append(padded_audio)
# audio_lens.append(audio.shape[1])
# padded_text = np.zeros([max_text_length])
# padded_text[:len(text)] = text
# texts.append(padded_text)
# text_lens.append(len(text))
# padded_audios = np.array(padded_audios).astype('float32')
# audio_lens = np.array(audio_lens).astype('int64')
# texts = np.array(texts).astype('int32')
# text_lens = np.array(text_lens).astype('int64')
# return padded_audios, texts, audio_lens, text_lens
# loader = DataLoader(
# dataset,
# batch_sampler=batch_sampler,
# collate_fn=partial(padding_batch, is_training=is_training),
# num_workers=num_workers, )
# return loader

@@ -81,9 +81,8 @@ FILES = glob.glob('kenlm/util/*.cc') \
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')
FILES = [
fn for fn in FILES
if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
'unittest.cc'))
fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc')
or fn.endswith('unittest.cc'))
]
LIBS = ['stdc++']

@@ -15,10 +15,10 @@ export FLAGS_sync_nccl_allreduce=0
#--shuffle_method="batch_shuffle_clipped" \
#CUDA_VISIBLE_DEVICES=0,1,2,3 \
CUDA_VISIBLE_DEVICES=1,2,3 \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python3 -u ${MAIN_ROOT}/train.py \
--device 'gpu' \
--nproc 1 \
--nproc 4 \
--config conf/deepspeech2.yaml \
--output ckpt

@@ -46,8 +46,8 @@ _C.model = CN(
num_conv_layers=2, #Number of stacking convolution layers.
num_rnn_layers=3, #Number of stacking RNN layers.
rnn_layer_size=1024, #RNN layer size (number of RNN cells).
use_gru=True, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=False #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
use_gru=False, #Use gru if set True. Use simple rnn if set False.
share_rnn_weights=True #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported.
))
_C.training = CN(
@@ -62,6 +62,20 @@ _C.training = CN(
n_epoch=50, # train epochs
))
_C.decoding = CN(
dict(
alpha=2.5, # Coef of LM for beam search.
beta=0.3, # Coef of WC for beam search.
cutoff_prob=1.0, # Cutoff probability for pruning.
cutoff_top_n=40, # Cutoff number for pruning.
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy
error_rate_type='wer', # Error rate type for evaluation. Options: 'wer', 'cer'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=500, # Beam search width.
batch_size=128, # decoding batch size
))
def get_cfg_defaults():
"""Get a yacs CfgNode object with default values for my_project."""

@@ -43,7 +43,138 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
class DeepSpeech2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self._ext_scorer = None
def compute_losses(self, inputs, outputs):
_, texts, _, texts_len = inputs
logits, _, logits_len = outputs
loss = self.criterion(logits, texts, logits_len, texts_len)
return loss
def read_batch(self):
"""Read a batch from the train_loader.
Returns
-------
List[Tensor]
A batch.
"""
try:
batch = next(self.iterator)
except StopIteration:
# do not silently start a new epoch here; train() catches
# StopIteration to run validation, checkpointing and the rollover
raise
return batch
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.optimizer.clear_grad()
self.model.train()
audio, text, audio_len, text_len = batch
outputs = self.model(audio, text, audio_len, text_len)
loss = self.compute_losses(batch, outputs)
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
losses_np = {'train_loss': float(loss)}
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
self.logger.info(msg)
if dist.get_rank() == 0:
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
self.new_epoch()
while self.epoch <= self.config.training.n_epoch:
try:
self.iteration += 1
self.train_batch()
if self.iteration % self.config.training.valid_interval == 0:
self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save()
except StopIteration:
self.iteration -= 1  # epoch ended; the counter ran one step ahead
self.valid()
self.save()
self.new_epoch()
def compute_metrics(self, inputs, outputs):
pass
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
self.model.eval()
valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader):
audio, text, audio_len, text_len = batch
outputs = self.model(audio, text, audio_len, text_len)
loss = self.compute_losses(batch, outputs)
metrics = self.compute_metrics(batch, outputs)
valid_losses['val_loss'].append(float(loss))
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_losses.items())
self.logger.info(msg)
for k, v in valid_losses.items():
self.visualizer.add_scalar("valid/{}".format(k), v, self.iteration)
def setup_model(self):
config = self.config
model = DeepSpeech2(
feat_size=self.train_loader.dataset.feature_size,
dict_size=self.train_loader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel:
model = paddle.DataParallel(model)
grad_clip = paddle.nn.ClipGradByGlobalNorm(
config.training.global_grad_clip)
optimizer = paddle.optimizer.Adam(
learning_rate=config.training.lr,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.logger.info("Setup model/optimizer/criterion!")
def setup_dataloader(self):
config = self.config
@@ -119,206 +250,91 @@ class DeepSpeech2Trainer(Trainer):
collate_fn=collate_fn)
self.logger.info("Setup train/valid Dataloader!")
def setup_model(self):
config = self.config
model = DeepSpeech2(
feat_size=self.train_loader.dataset.feature_size,
dict_size=self.train_loader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel:
model = paddle.DataParallel(model)
grad_clip = paddle.nn.ClipGradByGlobalNorm(
config.training.global_grad_clip)
optimizer = paddle.optimizer.Adam(
learning_rate=config.training.lr,
parameters=model.parameters(),
weight_decay=paddle.regularizer.L2Decay(
config.training.weight_decay),
grad_clip=grad_clip)
criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.logger.info("Setup model/optimizer/criterion!")
class DeepSpeech2Tester(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
def compute_losses(self, inputs, outputs):
del inputs
logits, texts, logits_len, texts_len = outputs
_, texts, _, texts_len = inputs
logits, _, logits_len = outputs
loss = self.criterion(logits, texts, logits_len, texts_len)
return loss
def train_batch(self):
start = time.time()
batch = self.read_batch()
data_loader_time = time.time() - start
self.optimizer.clear_grad()
self.model.train()
audio, text, audio_len, text_len = batch
outputs = self.model(audio, text, audio_len, text_len)
loss = self.compute_losses(batch, outputs)
loss.backward()
self.optimizer.step()
iteration_time = time.time() - start
losses_np = {'train_loss': float(loss)}
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
iteration_time)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_np.items())
self.logger.info(msg)
if dist.get_rank() == 0:
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
def compute_metrics(self, inputs, outputs):
_, texts, _, texts_len = inputs
logits, _, logits_len = outputs
pass
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
valid_losses = defaultdict(list)
for i, batch in enumerate(self.valid_loader):
def test(self):
self.model.eval()
losses = defaultdict(list)
for i, batch in enumerate(self.test_loader):
audio, text, audio_len, text_len = batch
outputs = self.model(audio, text, audio_len, text_len)
outputs = self.model.predict(audio, audio_len)
loss = self.compute_losses(batch, outputs)
metrics = self.compute_metrics(batch, outputs)
valid_losses['val_loss'].append(float(loss))
losses['test_loss'].append(float(loss))
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
losses = {k: np.mean(v) for k, v in losses.items()}
# logging
msg = "Valid: "
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_losses.items())
msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in losses.items())
self.logger.info(msg)
for k, v in valid_losses.items():
self.visualizer.add_scalar("valid/{}".foramt(k), v, self.iteration)
def infer_batch_probs(self, infer_data):
"""Infer the prob matrices for a batch of speech utterances.
:param infer_data: List of utterances to infer, with each utterance
consisting of a tuple of audio features and
transcription text (empty string).
:type infer_data: list
:return: List of 2-D probability matrices, each consisting of prob
vectors for one speech utterance.
:rtype: List of matrix
"""
self.model.eval()
audio, text, audio_len, text_len = infer_data
_, probs = self.model.predict(audio, audio_len)
return probs
def decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrices, each consisting
of prob vectors for one speech utterance.
:type probs_split: list of matrix
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:return: List of transcription texts.
:rtype: List of str
"""
results = []
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list)
results.append(output_transcription)
print(results)
return results
def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
vocab_list):
"""Initialize the external scorer.
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param language_model_path: Filepath for language model. If it is
empty, the external scorer will be set to
None, and the decoding method will be pure
beam search without scorer.
:type language_model_path: str|None
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
"""
if language_model_path != '':
self.logger.info("begin to initialize the external scorer "
"for decoding")
self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path, vocab_list)
lm_char_based = self._ext_scorer.is_character_based()
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
self.logger.info("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
self.logger.info("end initializing scorer")
else:
self._ext_scorer = None
self.logger.info("no language model provided, "
"decoding by pure beam search without scorer.")
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes):
"""Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrices, each consisting
of prob vectors for one speech utterance.
:type probs_split: list of matrix
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param beam_size: Width for Beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:param num_processes: Number of processes (CPU) for decoder.
:type num_processes: int
:return: List of transcription texts.
:rtype: List of str
"""
if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n)
results = [result[0][1] for result in beam_search_results]
return results
for k, v in losses.items():
self.visualizer.add_scalar("test/{}".format(k), v, self.iteration)
def setup_model(self):
config = self.config
model = DeepSpeech2(
feat_size=self.train_loader.dataset.feature_size,
dict_size=self.train_loader.dataset.vocab_size,
num_conv_layers=config.model.num_conv_layers,
num_rnn_layers=config.model.num_rnn_layers,
rnn_size=config.model.rnn_layer_size,
share_rnn_weights=config.model.share_rnn_weights)
if self.parallel:
model = paddle.DataParallel(model)
criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
self.model = model
self.criterion = criterion
self.logger.info("Setup model/criterion!")
def setup_dataloader(self):
config = self.config
test_dataset = DeepSpeech2Dataset(
config.data.test_manifest,
config.data.vocab_filepath,
config.data.mean_std_filepath,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
stride_ms=config.data.stride_ms,
window_ms=config.data.window_ms,
n_fft=config.data.n_fft,
max_freq=config.data.max_freq,
target_sample_rate=config.data.target_sample_rate,
specgram_type=config.data.specgram_type,
use_dB_normalization=config.data.use_dB_normalization,
target_dB=config.data.target_dB,
random_seed=config.data.random_seed,
keep_transcription_text=False)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.data.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn)
self.logger.info("Setup test Dataloader!")

@@ -15,11 +15,17 @@
import math
import collections
import numpy as np
import logging
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
__all__ = ['DeepSpeech2', 'DeepSpeech2Loss']
@@ -497,7 +503,10 @@ class DeepSpeech2(nn.Layer):
share_rnn_weights=share_rnn_weights)
self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
def predict(self, audio, audio_len):
self.logger = logging.getLogger(__name__)
self._ext_scorer = None
def infer(self, audio, audio_len):
# [B, D, T] -> [B, C=1, D, T]
audio = audio.unsqueeze(1)
@@ -519,11 +528,6 @@ class DeepSpeech2(nn.Layer):
return logits, probs, audio_len
@paddle.no_grad()
def infer(self, audio, audio_len):
_, probs, audio_len = self.predict(audio, audio_len)
return probs
def forward(self, audio, text, audio_len, text_len):
"""
audio: shape [B, D, T]
@@ -531,8 +535,138 @@ class DeepSpeech2(nn.Layer):
audio_len: shape [B]
text_len: shape [B]
"""
logits, _, audio_len = self.predict(audio, audio_len)
return logits, text, audio_len, text_len
return self.infer(audio, audio_len)
@paddle.no_grad()
def predict(self, audio, audio_len):
""" Model infer """
return self.infer(audio, audio_len)
def _decode_batch_greedy(self, probs_split, vocab_list):
"""Decode by best path for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrices, each consisting
of prob vectors for one speech utterance.
:type probs_split: list of matrix
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:return: List of transcription texts.
:rtype: List of str
"""
results = []
for i, probs in enumerate(probs_split):
output_transcription = ctc_greedy_decoder(
probs_seq=probs, vocabulary=vocab_list)
results.append(output_transcription)
return results
def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
vocab_list):
"""Initialize the external scorer.
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param language_model_path: Filepath for language model. If it is
empty, the external scorer will be set to
None, and the decoding method will be pure
beam search without scorer.
:type language_model_path: str|None
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
"""
# init once
if self._ext_scorer is not None:
return
if language_model_path != '':
self.logger.info("begin to initialize the external scorer "
"for decoding")
self._ext_scorer = Scorer(beam_alpha, beam_beta,
language_model_path, vocab_list)
lm_char_based = self._ext_scorer.is_character_based()
lm_max_order = self._ext_scorer.get_max_order()
lm_dict_size = self._ext_scorer.get_dict_size()
self.logger.info("language model: "
"is_character_based = %d," % lm_char_based +
" max_order = %d," % lm_max_order +
" dict_size = %d" % lm_dict_size)
self.logger.info("end initializing scorer")
else:
self._ext_scorer = None
self.logger.info("no language model provided, "
"decoding by pure beam search without scorer.")
def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n,
vocab_list, num_processes):
"""Decode by beam search for a batch of probs matrix input.
:param probs_split: List of 2-D probability matrices, each consisting
of prob vectors for one speech utterance.
:type probs_split: list of matrix
:param beam_alpha: Parameter associated with language model.
:type beam_alpha: float
:param beam_beta: Parameter associated with word count.
:type beam_beta: float
:param beam_size: Width for Beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param vocab_list: List of tokens in the vocabulary, for decoding.
:type vocab_list: list
:param num_processes: Number of processes (CPU) for decoder.
:type num_processes: int
:return: List of transcription texts.
:rtype: List of str
"""
if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode
num_processes = min(num_processes, len(probs_split))
beam_search_results = ctc_beam_search_decoder_batch(
probs_split=probs_split,
vocabulary=vocab_list,
beam_size=beam_size,
num_processes=num_processes,
ext_scoring_func=self._ext_scorer,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n)
results = [result[0][1] for result in beam_search_results]
return results
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
decoding_method):
if decoding_method == "ctc_beam_search":
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list)
@paddle.no_grad()
def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes):
_, probs, _ = self.predict(audio, audio_len)
if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy(
probs_split=probs, vocab_list=vocab_list)
elif decoding_method == "ctc_beam_search":
result_transcripts = self._decode_batch_beam_search(
probs_split=probs,
beam_alpha=beam_alpha,
beam_beta=beam_beta,
beam_size=beam_size,
cutoff_prob=cutoff_prob,
cutoff_top_n=cutoff_top_n,
vocab_list=vocab_list,
num_processes=num_processes)
else:
raise ValueError(f"Not support: {decoding_method}")
return result_transcripts
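A hedged usage sketch of the new decode API, wired to the decoding config added in model_utils/config.py; the `model`, `audio`, `audio_len` and `vocab_list` names are assumed to come from the test setup:

    cfg = config.decoding
    # one-time scorer setup when beam search is selected
    model.init_decode(cfg.alpha, cfg.beta, cfg.lang_model_path,
                      vocab_list, cfg.decoding_method)
    result_transcripts = model.decode(
        audio, audio_len,
        vocab_list=vocab_list,
        decoding_method=cfg.decoding_method,
        lang_model_path=cfg.lang_model_path,
        beam_alpha=cfg.alpha,
        beam_beta=cfg.beta,
        beam_size=cfg.beam_size,
        cutoff_prob=cfg.cutoff_prob,
        cutoff_top_n=cfg.cutoff_top_n,
        num_processes=cfg.num_proc_bsearch)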
class DeepSpeech2Loss(nn.Layer):

@@ -1,279 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict
import parakeet
from parakeet.utils import checkpoint, mp_tools
__all__ = ["ExperimentBase"]
class ExperimentBase(object):
"""
An experiment template in order to structure the training code and take
care of saving, loading, logging and visualization. It's intended to
be flexible and simple.
So it only handles output directory (create directory for the output,
create a checkpoint directory, dump the config in use and create
visualizer and logger) in a standard way without enforcing any
input-output protocols to the model and dataloader. It leaves the main
part for the user to implement their own (setup the model, criterion,
optimizer, define a training step, define a validation function and
customize all the text and visual logs).
It does not save too much boilerplate code. The users still have to write
the forward/backward/update steps manually, but they are free to add
non-standard behaviors if needed.
We have some conventions to follow.
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
``valid_loader``, ``config`` and ``args`` attributes.
2. The config should have a ``training`` field, which has
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
used as the trigger to invoke validation, checkpointing and stop of the
experiment.
3. There are four methods, namely ``train_batch``, ``valid``,
``setup_model`` and ``setup_dataloader`` that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you
need.
Parameters
----------
config: yacs.config.CfgNode
The configuration used for the experiment.
args: argparse.Namespace
The parsed command line arguments.
Examples
--------
>>> def main_sp(config, args):
>>> exp = Experiment(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.nprocs > 1 and args.device == "gpu":
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else:
>>> main_sp(config, args)
"""
def __init__(self, config, args):
self.config = config
self.args = args
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
if self.parallel:
self.init_parallel()
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_logger()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
@property
def parallel(self):
"""A flag indicating whether the experiment should run with
multiprocessing.
"""
return self.args.device == "gpu" and self.args.nprocs > 1
def init_parallel(self):
"""Init environment for multiprocess training.
"""
dist.init_parallel_env()
def save(self):
"""Save checkpoint (model parameters and optimizer states).
"""
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer)
def load_or_resume(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
iteration = checkpoint.load_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration
def read_batch(self):
"""Read a batch from the train_loader.
Returns
-------
List[Tensor]
A batch.
"""
try:
batch = next(self.iterator)
except StopIteration:
self.new_epoch()
batch = next(self.iterator)
return batch
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
"""
self.epoch += 1
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.iterator = iter(self.train_loader)
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
self.new_epoch()
while self.iteration < self.config.training.max_iteration:
self.iteration += 1
self.train_batch()
if self.iteration % self.config.training.valid_interval == 0:
self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save()
def run(self):
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.load_or_resume()
try:
self.train()
except KeyboardInterrupt:
self.save()
exit(-1)
@mp_tools.rank_zero_only
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.
The visual log is saved in the output directory.
Notes
------
Only the main process has a visualizer. Using multiple
visualizers in multiple processes to write to the same log file may
cause unexpected behavior.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
self.visualizer = visualizer
def setup_logger(self):
"""Initialize a text logger to log the experiment.
Each process has its own text logger. The logging messages are written to
the standard output and a text file named ``worker_n.log`` in the
output directory, where ``n`` means the rank of the process.
"""
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
logger.addHandler(logging.StreamHandler())
log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
logger.addHandler(logging.FileHandler(str(log_file)))
self.logger = logger
@mp_tools.rank_zero_only
def dump_config(self):
"""Save the configuration used for this experiment.
It is saved in to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
print(self.config, file=f)
def train_batch(self):
"""The training loop. A subclass should implement this method.
"""
raise NotImplementedError("train_batch should be implemented.")
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
"""The validation. A subclass should implement this method.
"""
raise NotImplementedError("valid should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.
"""
raise NotImplementedError("setup_model should be implemented.")
def setup_dataloader(self):
"""Setup training dataloader and validation dataloader. A subclass
should implement this method.
"""
raise NotImplementedError("setup_dataloader should be implemented.")

@@ -13,119 +13,38 @@
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
import paddle.fluid as fluid
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from model_utils.model_check import check_cuda, check_version
from paddle import distributed as dist
from utils.utility import print_arguments
from training.cli import default_argument_parser
from model_utils.config import get_cfg_defaults
from model_utils.model import DeepSpeech2Trainer as Trainer
from utils.error_rate import char_errors, word_errors
from utils.utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
"bi-directional RNNs. Not for GRU.")
add_arg('test_manifest', str,
'data/librispeech/manifest.test-clean',
"Filepath of manifest to evaluate.")
add_arg('mean_std_path', str,
'data/librispeech/mean_std.npz',
"Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
'data/librispeech/vocab.txt',
"Filepath of vocabulary.")
add_arg('model_path', str,
'./checkpoints/libri/step_final',
"If None, the training starts from scratch, "
"otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
'models/lm/common_crawl_00.prune01111.trie.klm',
"Filepath for language model.")
add_arg('decoding_method', str,
'ctc_beam_search',
"Decoding method. Options: ctc_beam_search, ctc_greedy",
choices = ['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type', str,
'wer',
"Error rate type for evaluation.",
choices=['wer', 'cer'])
add_arg('specgram_type', str,
'linear',
"Audio feature type. Options: linear, mfcc.",
choices=['linear', 'mfcc'])
# yapf: disable
args = parser.parse_args()
def evaluate():
"""Evaluate on whole test data for DeepSpeech2."""
# check if set use_gpu=True in paddlepaddle cpu version
check_cuda(args.use_gpu)
# check if paddlepaddle version is satisfied
check_version()
if args.use_gpu:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
data_generator = DataGenerator(
vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config='{}',
specgram_type=args.specgram_type,
keep_transcription_text=True,
place = place,
is_training = False)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.test_manifest,
batch_size=args.batch_size,
sortagrad=False,
shuffle_method=None)
ds2_model = DeepSpeech2Model(
vocab_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_layer_size=args.rnn_layer_size,
use_gru=args.use_gru,
share_rnn_weights=args.share_rnn_weights,
place=place,
init_from_pretrained_model=args.model_path)
# decoders only accept string encoded in utf-8
vocab_list = [chars for chars in data_generator.vocab_list]
if args.decoding_method == "ctc_beam_search":
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
vocab_list)
errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
errors_sum, len_refs, num_ins = 0.0, 0, 0
ds2_model.logger.info("start evaluation ...")
for infer_data in batch_reader():
probs_split = ds2_model.infer_batch_probs(
infer_data=infer_data,
feeding_dict=data_generator.feeding)
infer_data=infer_data, feeding_dict=data_generator.feeding)
if args.decoding_method == "ctc_greedy":
result_transcripts = ds2_model.decode_batch_greedy(
probs_split=probs_split,
vocab_list=vocab_list)
probs_split=probs_split, vocab_list=vocab_list)
else:
result_transcripts = ds2_model.decode_batch_beam_search(
probs_split=probs_split,
@@ -136,6 +55,7 @@ def evaluate():
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
num_processes=args.num_proc_bsearch)
target_transcripts = infer_data[1]
for target, result in zip(target_transcripts, result_transcripts):
@@ -145,15 +65,38 @@ def evaluate():
num_ins += 1
print("Error rate [%s] (%d/?) = %f" %
(args.error_rate_type, num_ins, errors_sum / len_refs))
print("Final error rate [%s] (%d/%d) = %f" %
(args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))
ds2_model.logger.info("finish evaluation")
def main():
print_arguments(args)
evaluate()
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args)
if __name__ == '__main__':
main()
# https://yaml.org/type/float.html
config = get_cfg_defaults()
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)

@@ -121,6 +121,7 @@ class Trainer():
"""
dist.init_parallel_env()
@mp_tools.rank_zero_only
def save(self):
"""Save checkpoint (model parameters and optimizer states).
"""
@@ -190,8 +191,9 @@ class Trainer():
except KeyboardInterrupt:
self.save()
exit(-1)
finally:
self.destory()
@mp_tools.rank_zero_only
def setup_output_dir(self):
"""Create a directory used for output.
"""
@@ -201,7 +203,6 @@ class Trainer():
self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
@@ -213,6 +214,11 @@ class Trainer():
self.checkpoint_dir = checkpoint_dir
@mp_tools.rank_zero_only
def destory(self):
# https://github.com/pytorch/fairseq/issues/2357
self.visualizer.close()
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.

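The `finally: self.destory()` hook above is what "close tensorboard when train over" refers to: the SummaryWriter must be closed so pending events are flushed even when training crashes or is interrupted. A minimal standalone sketch of the pattern, with a hypothetical log directory:

    from tensorboardX import SummaryWriter

    writer = SummaryWriter(logdir='output')  # hypothetical logdir
    try:
        for step in range(100):
            writer.add_scalar('train/loss', 1.0 / (step + 1), step)
    finally:
        writer.close()  # flush pending events to disk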
@@ -14,6 +14,7 @@
import os
import time
import logging
import numpy as np
import paddle
from paddle import distributed as dist
@@ -22,6 +23,9 @@ from paddle.optimizer import Optimizer
from utils import mp_tools
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
__all__ = ["load_parameters", "save_parameters"]
@@ -94,15 +98,15 @@ def load_parameters(model,
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
print(
logger.info(
"[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
logger.info("[checkpoint] Rank {}: loaded optimizer state from {}".
format(rank, optimizer_path))
return iteration
@@ -124,12 +128,13 @@ def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
print("[checkpoint] Saved model to {}".format(params_path))
logger.info("[checkpoint] Saved model to {}".format(params_path))
if optimizer:
opt_dict = optimizer.state_dict()
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
print("[checkpoint] Saved optimzier state to {}".format(optimizer_path))
logger.info(
"[checkpoint] Saved optimzier state to {}".format(optimizer_path))
_save_checkpoint(checkpoint_dir, iteration)

@@ -20,13 +20,12 @@ __all__ = ["rank_zero_only"]
def rank_zero_only(func):
rank = dist.get_rank()
@wraps(func)
def wrapper(*args, **kwargs):
rank = dist.get_rank()
if rank != 0:
return
result = func(*args, **kwargs)
return result
return wrapper
return wrapper
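The decorator fix above moves `dist.get_rank()` from decoration time into the wrapper body: at import time the parallel environment may not be initialized yet, and each spawned worker must see its own rank, so the rank has to be read on every call. The corrected shape, as a standalone sketch:

    from functools import wraps
    from paddle import distributed as dist

    def rank_zero_only(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # query the rank at call time, after dist.init_parallel_env()
            if dist.get_rank() != 0:
                return
            return func(*args, **kwargs)
        return wrapper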
