diff --git a/deepspeech/exps/deepspeech2/config.py b/deepspeech/exps/deepspeech2/config.py index 7d2250fc..2f0f5c24 100644 --- a/deepspeech/exps/deepspeech2/config.py +++ b/deepspeech/exps/deepspeech2/config.py @@ -13,12 +13,11 @@ # limitations under the License. from yacs.config import CfgNode -from deepspeech.models.deepspeech2 import DeepSpeech2Model -from deepspeech.io.dataset import ManifestDataset -from deepspeech.io.collator import SpeechCollator -from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester - +from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer +from deepspeech.io.collator import SpeechCollator +from deepspeech.io.dataset import ManifestDataset +from deepspeech.models.deepspeech2 import DeepSpeech2Model _C = CfgNode() diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index c11d1e25..deb8752b 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -15,11 +15,13 @@ import time from collections import defaultdict from pathlib import Path +from typing import Optional import numpy as np import paddle from paddle import distributed as dist from paddle.io import DataLoader +from yacs.config import CfgNode from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -33,9 +35,6 @@ from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools from deepspeech.utils.log import Log - -from typing import Optional -from yacs.config import CfgNode logger = Log(__name__).getlog() @@ -44,13 +43,13 @@ class DeepSpeech2Trainer(Trainer): def params(cls, config: Optional[CfgNode]=None) -> CfgNode: # training config default = CfgNode( - dict( - lr=5e-4, # learning rate - lr_decay=1.0, # learning rate decay - weight_decay=1e-6, # the coeff of weight decay - global_grad_clip=5.0, # the global norm clip - n_epoch=50, # train epochs - )) + dict( + lr=5e-4, # learning rate + lr_decay=1.0, # learning rate decay + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=5.0, # the global norm clip + n_epoch=50, # train epochs + )) if config is not None: config.merge_from_other_cfg(default) @@ -184,7 +183,6 @@ class DeepSpeech2Trainer(Trainer): collate_fn_train = SpeechCollator.from_config(config) - config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) self.train_loader = DataLoader( @@ -206,18 +204,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): def params(cls, config: Optional[CfgNode]=None) -> CfgNode: # testing config default = CfgNode( - dict( - alpha=2.5, # Coef of LM for beam search. - beta=0.3, # Coef of WC for beam search. - cutoff_prob=1.0, # Cutoff probability for pruning. - cutoff_top_n=40, # Cutoff number for pruning. - lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. - decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy - error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' - num_proc_bsearch=8, # # of CPUs for beam search. - beam_size=500, # Beam search width. - batch_size=128, # decoding batch size - )) + dict( + alpha=2.5, # Coef of LM for beam search. + beta=0.3, # Coef of WC for beam search. + cutoff_prob=1.0, # Cutoff probability for pruning. + cutoff_top_n=40, # Cutoff number for pruning. + lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model. + decoding_method='ctc_beam_search', # Decoding method. Options: ctc_beam_search, ctc_greedy + error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer' + num_proc_bsearch=8, # # of CPUs for beam search. + beam_size=500, # Beam search width. + batch_size=128, # decoding batch size + )) if config is not None: config.merge_from_other_cfg(default) @@ -235,7 +233,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -257,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref @@ -287,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): utts, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + metrics = self.compute_metrics(utts, audio, audio_len, texts, + texts_len, fout) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] diff --git a/deepspeech/exps/u2/config.py b/deepspeech/exps/u2/config.py index d8735453..4ec7bd19 100644 --- a/deepspeech/exps/u2/config.py +++ b/deepspeech/exps/u2/config.py @@ -15,9 +15,9 @@ from yacs.config import CfgNode from deepspeech.exps.u2.model import U2Tester from deepspeech.exps.u2.model import U2Trainer +from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset from deepspeech.models.u2 import U2Model -from deepspeech.io.collator import SpeechCollator _C = CfgNode() diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 836afa36..05551875 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -78,7 +78,8 @@ class U2Trainer(Trainer): start = time.time() utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -121,7 +122,8 @@ class U2Trainer(Trainer): total_loss = 0.0 for i, batch in enumerate(self.valid_loader): utt, audio, audio_len, text, text_len = batch - loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, + text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += num_utts @@ -221,7 +223,7 @@ class U2Trainer(Trainer): dev_dataset = ManifestDataset.from_config(config) collate_fn_train = SpeechCollator.from_config(config) - + config.collator.augmentation_config = "" collate_fn_dev = SpeechCollator.from_config(config) @@ -372,7 +374,13 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, + utts, + audio, + audio_len, + texts, + texts_len, + fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -399,7 +407,8 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, + result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index ab1e9165..ecf7024c 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -11,21 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import io +import time +from collections import namedtuple +from typing import Optional + import numpy as np +from yacs.config import CfgNode -from deepspeech.frontend.utility import IGNORE_ID -from deepspeech.io.utility import pad_sequence -from deepspeech.utils.log import Log from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer from deepspeech.frontend.normalizer import FeatureNormalizer from deepspeech.frontend.speech import SpeechSegment -import io -import time -from yacs.config import CfgNode -from typing import Optional - -from collections import namedtuple +from deepspeech.frontend.utility import IGNORE_ID +from deepspeech.io.utility import pad_sequence +from deepspeech.utils.log import Log __all__ = ["SpeechCollator"] @@ -34,6 +34,7 @@ logger = Log(__name__).getlog() # namedtupe need global for pickle. TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object']) + class SpeechCollator(): @classmethod def params(cls, config: Optional[CfgNode]=None) -> CfgNode: @@ -56,8 +57,7 @@ class SpeechCollator(): use_dB_normalization=True, target_dB=-20, dither=1.0, # feature dither - keep_transcription_text=False - )) + keep_transcription_text=False)) if config is not None: config.merge_from_other_cfg(default) @@ -84,7 +84,9 @@ class SpeechCollator(): if isinstance(config.collator.augmentation_config, (str, bytes)): if config.collator.augmentation_config: aug_file = io.open( - config.collator.augmentation_config, mode='r', encoding='utf8') + config.collator.augmentation_config, + mode='r', + encoding='utf8') else: aug_file = io.StringIO(initial_value='{}', newline='') else: @@ -92,43 +94,46 @@ class SpeechCollator(): assert isinstance(aug_file, io.StringIO) speech_collator = cls( - aug_file=aug_file, - random_seed=0, - mean_std_filepath=config.collator.mean_std_filepath, - unit_type=config.collator.unit_type, - vocab_filepath=config.collator.vocab_filepath, - spm_model_prefix=config.collator.spm_model_prefix, - specgram_type=config.collator.specgram_type, - feat_dim=config.collator.feat_dim, - delta_delta=config.collator.delta_delta, - stride_ms=config.collator.stride_ms, - window_ms=config.collator.window_ms, - n_fft=config.collator.n_fft, - max_freq=config.collator.max_freq, - target_sample_rate=config.collator.target_sample_rate, - use_dB_normalization=config.collator.use_dB_normalization, - target_dB=config.collator.target_dB, - dither=config.collator.dither, - keep_transcription_text=config.collator.keep_transcription_text - ) + aug_file=aug_file, + random_seed=0, + mean_std_filepath=config.collator.mean_std_filepath, + unit_type=config.collator.unit_type, + vocab_filepath=config.collator.vocab_filepath, + spm_model_prefix=config.collator.spm_model_prefix, + specgram_type=config.collator.specgram_type, + feat_dim=config.collator.feat_dim, + delta_delta=config.collator.delta_delta, + stride_ms=config.collator.stride_ms, + window_ms=config.collator.window_ms, + n_fft=config.collator.n_fft, + max_freq=config.collator.max_freq, + target_sample_rate=config.collator.target_sample_rate, + use_dB_normalization=config.collator.use_dB_normalization, + target_dB=config.collator.target_dB, + dither=config.collator.dither, + keep_transcription_text=config.collator.keep_transcription_text) return speech_collator - def __init__(self, aug_file, mean_std_filepath, - vocab_filepath, spm_model_prefix, - random_seed=0, - unit_type="char", - specgram_type='linear', # 'linear', 'mfcc', 'fbank' - feat_dim=0, # 'mfcc', 'fbank' - delta_delta=False, # 'mfcc', 'fbank' - stride_ms=10.0, # ms - window_ms=20.0, # ms - n_fft=None, # fft points - max_freq=None, # None for samplerate/2 - target_sample_rate=16000, # target sample rate - use_dB_normalization=True, - target_dB=-20, - dither=1.0, - keep_transcription_text=True): + def __init__( + self, + aug_file, + mean_std_filepath, + vocab_filepath, + spm_model_prefix, + random_seed=0, + unit_type="char", + specgram_type='linear', # 'linear', 'mfcc', 'fbank' + feat_dim=0, # 'mfcc', 'fbank' + delta_delta=False, # 'mfcc', 'fbank' + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + target_sample_rate=16000, # target sample rate + use_dB_normalization=True, + target_dB=-20, + dither=1.0, + keep_transcription_text=True): """SpeechCollator Collator Args: @@ -159,9 +164,8 @@ class SpeechCollator(): self._local_data = TarLocalData(tar2info={}, tar2object={}) self._augmentation_pipeline = AugmentationPipeline( - augmentation_config=aug_file.read(), - random_seed=random_seed) - + augmentation_config=aug_file.read(), random_seed=random_seed) + self._normalizer = FeatureNormalizer( mean_std_filepath) if mean_std_filepath else None @@ -290,8 +294,6 @@ class SpeechCollator(): text_lens = np.array(text_lens).astype(np.int64) return utts, padded_audios, audio_lens, padded_texts, text_lens - - @property def manifest(self): return self._manifest @@ -318,4 +320,4 @@ class SpeechCollator(): @property def stride_ms(self): - return self._speech_featurizer.stride_ms \ No newline at end of file + return self._speech_featurizer.stride_ms diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index 70383b4d..92c60f35 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -12,19 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import io -import tarfile -import time -from collections import namedtuple from typing import Optional -import numpy as np from paddle.io import Dataset from yacs.config import CfgNode -from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline -from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer -from deepspeech.frontend.normalizer import FeatureNormalizer -from deepspeech.frontend.speech import SpeechSegment from deepspeech.frontend.utility import read_manifest from deepspeech.utils.log import Log @@ -46,8 +38,7 @@ class ManifestDataset(Dataset): max_output_len=float('inf'), min_output_len=0.0, max_output_input_ratio=float('inf'), - min_output_input_ratio=0.0, - )) + min_output_input_ratio=0.0, )) if config is not None: config.merge_from_other_cfg(default) @@ -66,7 +57,6 @@ class ManifestDataset(Dataset): assert 'manifest' in config.data assert config.data.manifest - dataset = cls( manifest_path=config.data.manifest, max_input_len=config.data.max_input_len, @@ -74,8 +64,7 @@ class ManifestDataset(Dataset): max_output_len=config.data.max_output_len, min_output_len=config.data.min_output_len, max_output_input_ratio=config.data.max_output_input_ratio, - min_output_input_ratio=config.data.min_output_input_ratio, - ) + min_output_input_ratio=config.data.min_output_input_ratio, ) return dataset def __init__(self, @@ -111,7 +100,6 @@ class ManifestDataset(Dataset): min_output_input_ratio=min_output_input_ratio) self._manifest.sort(key=lambda x: x["feat_shape"][0]) - def __len__(self): return len(self._manifest) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index bcfddaef..238e2d35 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -905,7 +905,6 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): super().__init__(configs) - def forward(self, feats, feats_lengths,