fix pre-commit

pull/665/head
Haoxin Ma 4 years ago
parent 089a8ed602
commit 3a743f3717

deepspeech/exps/deepspeech2/config.py
@@ -13,12 +13,11 @@
 # limitations under the License.
 from yacs.config import CfgNode

-from deepspeech.models.deepspeech2 import DeepSpeech2Model
-from deepspeech.io.dataset import ManifestDataset
-from deepspeech.io.collator import SpeechCollator
-from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
 from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester
+from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
+from deepspeech.io.collator import SpeechCollator
+from deepspeech.io.dataset import ManifestDataset
+from deepspeech.models.deepspeech2 import DeepSpeech2Model

 _C = CfgNode()

@ -15,11 +15,13 @@
import time import time
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Optional
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator import SpeechCollator from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
@@ -33,9 +35,6 @@ from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
 from deepspeech.utils.log import Log
-from typing import Optional
-from yacs.config import CfgNode

 logger = Log(__name__).getlog()
@@ -44,13 +43,13 @@ class DeepSpeech2Trainer(Trainer):
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
         # training config
         default = CfgNode(
             dict(
                 lr=5e-4,  # learning rate
                 lr_decay=1.0,  # learning rate decay
                 weight_decay=1e-6,  # the coeff of weight decay
                 global_grad_clip=5.0,  # the global norm clip
                 n_epoch=50,  # train epochs
             ))

         if config is not None:
             config.merge_from_other_cfg(default)
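The `params()` classmethod in this hunk is the repo's pattern for declaring per-class config defaults with yacs. A minimal runnable sketch of that pattern, assuming only `yacs` is installed (key names copied from the hunk; the standalone function is a simplification of the classmethod):

```python
from typing import Optional

from yacs.config import CfgNode


def params(config: Optional[CfgNode]=None) -> CfgNode:
    # a subset of the training defaults shown above
    default = CfgNode(dict(
        lr=5e-4,  # learning rate
        n_epoch=50,  # train epochs
    ))
    if config is not None:
        # yacs merges `default` into `config`; on colliding keys,
        # values from `default` win
        config.merge_from_other_cfg(default)
        return config
    return default


print(params().lr)  # 5e-4
```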
@ -184,7 +183,6 @@ class DeepSpeech2Trainer(Trainer):
collate_fn_train = SpeechCollator.from_config(config) collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = "" config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config) collate_fn_dev = SpeechCollator.from_config(config)
self.train_loader = DataLoader( self.train_loader = DataLoader(
@@ -206,18 +204,18 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
         # testing config
         default = CfgNode(
             dict(
                 alpha=2.5,  # Coef of LM for beam search.
                 beta=0.3,  # Coef of WC for beam search.
                 cutoff_prob=1.0,  # Cutoff probability for pruning.
                 cutoff_top_n=40,  # Cutoff number for pruning.
                 lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm',  # Filepath for language model.
                 decoding_method='ctc_beam_search',  # Decoding method. Options: ctc_beam_search, ctc_greedy
                 error_rate_type='wer',  # Error rate type for evaluation. Options `wer`, 'cer'
                 num_proc_bsearch=8,  # # of CPUs for beam search.
                 beam_size=500,  # Beam search width.
                 batch_size=128,  # decoding batch size
             ))

         if config is not None:
             config.merge_from_other_cfg(default)
@@ -235,7 +233,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             trans.append(''.join([chr(i) for i in ids]))
         return trans

-    def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None):
+    def compute_metrics(self,
+                        utts,
+                        audio,
+                        audio_len,
+                        texts,
+                        texts_len,
+                        fout=None):
         cfg = self.config.decoding
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -257,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_top_n=cfg.cutoff_top_n, cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch) num_processes=cfg.num_proc_bsearch)
for utt, target, result in zip(utts, target_transcripts, result_transcripts): for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
errors_sum += errors errors_sum += errors
len_refs += len_ref len_refs += len_ref
@@ -287,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
         with open(self.args.result_file, 'w') as fout:
             for i, batch in enumerate(self.test_loader):
                 utts, audio, audio_len, texts, texts_len = batch
-                metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
+                metrics = self.compute_metrics(utts, audio, audio_len, texts,
+                                               texts_len, fout)
                 errors_sum += metrics['errors_sum']
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
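For reference, the accumulation above implements a corpus-level WER/CER: raw edit errors and reference lengths are summed over all utterances and divided once at the end, which is not the same as averaging per-utterance rates. A self-contained sketch, with an inlined `word_errors` standing in for `deepspeech.utils.error_rate.word_errors` (which likewise returns an `(errors, ref_length)` pair):

```python
def word_errors(reference: str, hypothesis: str):
    """Return (edit_distance, ref_length) over words."""
    ref, hyp = reference.split(), hypothesis.split()
    # classic dynamic-programming Levenshtein distance
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(prev[j] + 1,          # deletion
                         cur[j - 1] + 1,       # insertion
                         prev[j - 1] + (r != h))  # substitution
        prev = cur
    return float(prev[-1]), len(ref)


errors_sum, len_refs = 0.0, 0
for target, result in [('the cat sat', 'the cat sat'),
                       ('hello world', 'hello word')]:
    errors, len_ref = word_errors(target, result)
    errors_sum += errors
    len_refs += len_ref
print('WER: %f' % (errors_sum / len_refs))  # 0.2 = 1 error / 5 ref words
```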

deepspeech/exps/u2/config.py
@@ -15,9 +15,9 @@ from yacs.config import CfgNode
 from deepspeech.exps.u2.model import U2Tester
 from deepspeech.exps.u2.model import U2Trainer
+from deepspeech.io.collator import SpeechCollator
 from deepspeech.io.dataset import ManifestDataset
 from deepspeech.models.u2 import U2Model
-from deepspeech.io.collator import SpeechCollator

 _C = CfgNode()

deepspeech/exps/u2/model.py
@@ -78,7 +78,8 @@ class U2Trainer(Trainer):
         start = time.time()
         utt, audio, audio_len, text, text_len = batch_data

-        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
+        loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
+                                                    text_len)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
         loss.backward()
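The `loss /= train_conf.accum_grad` line is the usual gradient-accumulation trick: gradients from several micro-batches are summed before one optimizer step, so the loss is pre-scaled to keep the effective gradient equal to that of one large batch. A minimal sketch, assuming Paddle 2.x APIs and a toy linear model standing in for U2:

```python
import paddle

model = paddle.nn.Linear(4, 1)
optimizer = paddle.optimizer.SGD(learning_rate=1e-2,
                                 parameters=model.parameters())
accum_grad = 4  # same role as train_conf.accum_grad above

for step in range(8):
    x = paddle.randn([16, 4])
    y = paddle.randn([16, 1])
    loss = paddle.nn.functional.mse_loss(model(x), y)
    # scale so the gradients summed over `accum_grad` backward calls
    # match one batch of size 16 * accum_grad
    loss = loss / accum_grad
    loss.backward()
    # update (and reset) only every `accum_grad` micro-batches
    if (step + 1) % accum_grad == 0:
        optimizer.step()
        optimizer.clear_grad()
```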
@ -121,7 +122,8 @@ class U2Trainer(Trainer):
total_loss = 0.0 total_loss = 0.0
for i, batch in enumerate(self.valid_loader): for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
if paddle.isfinite(loss): if paddle.isfinite(loss):
num_utts = batch[1].shape[0] num_utts = batch[1].shape[0]
num_seen_utts += num_utts num_seen_utts += num_utts
@@ -221,7 +223,7 @@ class U2Trainer(Trainer):
         dev_dataset = ManifestDataset.from_config(config)

         collate_fn_train = SpeechCollator.from_config(config)

         config.collator.augmentation_config = ""
         collate_fn_dev = SpeechCollator.from_config(config)
@ -372,7 +374,13 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids])) trans.append(''.join([chr(i) for i in ids]))
return trans return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -399,7 +407,8 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming) simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time decode_time = time.time() - start_time
for utt, target, result in zip(utts, target_transcripts, result_transcripts): for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
errors_sum += errors errors_sum += errors
len_refs += len_ref len_refs += len_ref

deepspeech/io/collator.py
@@ -11,21 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import io
+import time
+from collections import namedtuple
+from typing import Optional

 import numpy as np
+from yacs.config import CfgNode

-from deepspeech.frontend.utility import IGNORE_ID
-from deepspeech.io.utility import pad_sequence
-from deepspeech.utils.log import Log
 from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
 from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
 from deepspeech.frontend.normalizer import FeatureNormalizer
 from deepspeech.frontend.speech import SpeechSegment
-import io
-import time
-from yacs.config import CfgNode
-from typing import Optional
-from collections import namedtuple
+from deepspeech.frontend.utility import IGNORE_ID
+from deepspeech.io.utility import pad_sequence
+from deepspeech.utils.log import Log

 __all__ = ["SpeechCollator"]
@@ -34,6 +34,7 @@ logger = Log(__name__).getlog()
 # namedtupe need global for pickle.
 TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])

+
 class SpeechCollator():
     @classmethod
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
@ -56,8 +57,7 @@ class SpeechCollator():
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
dither=1.0, # feature dither dither=1.0, # feature dither
keep_transcription_text=False keep_transcription_text=False))
))
if config is not None: if config is not None:
config.merge_from_other_cfg(default) config.merge_from_other_cfg(default)
@ -84,7 +84,9 @@ class SpeechCollator():
if isinstance(config.collator.augmentation_config, (str, bytes)): if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config: if config.collator.augmentation_config:
aug_file = io.open( aug_file = io.open(
config.collator.augmentation_config, mode='r', encoding='utf8') config.collator.augmentation_config,
mode='r',
encoding='utf8')
else: else:
aug_file = io.StringIO(initial_value='{}', newline='') aug_file = io.StringIO(initial_value='{}', newline='')
else: else:
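The reflowed `io.open(...)` call sits inside a useful fallback pattern: a non-empty string is treated as the path of a JSON augmentation config, while an empty one yields an in-memory `'{}'`, so downstream code can unconditionally call `.read()` on a file-like object. A standalone sketch (the function name is hypothetical):

```python
import io
import json


def open_augmentation_config(path: str):
    # non-empty string: treat as a path to a JSON config on disk
    if path:
        return io.open(path, mode='r', encoding='utf8')
    # empty string: an in-memory empty JSON document, no disk access
    return io.StringIO(initial_value='{}', newline='')


aug_file = open_augmentation_config('')  # no augmentation configured
print(json.loads(aug_file.read()))  # {} -> empty augmentation pipeline
```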
@@ -92,43 +94,46 @@ class SpeechCollator():
             assert isinstance(aug_file, io.StringIO)

         speech_collator = cls(
             aug_file=aug_file,
             random_seed=0,
             mean_std_filepath=config.collator.mean_std_filepath,
             unit_type=config.collator.unit_type,
             vocab_filepath=config.collator.vocab_filepath,
             spm_model_prefix=config.collator.spm_model_prefix,
             specgram_type=config.collator.specgram_type,
             feat_dim=config.collator.feat_dim,
             delta_delta=config.collator.delta_delta,
             stride_ms=config.collator.stride_ms,
             window_ms=config.collator.window_ms,
             n_fft=config.collator.n_fft,
             max_freq=config.collator.max_freq,
             target_sample_rate=config.collator.target_sample_rate,
             use_dB_normalization=config.collator.use_dB_normalization,
             target_dB=config.collator.target_dB,
             dither=config.collator.dither,
-            keep_transcription_text=config.collator.keep_transcription_text
-        )
+            keep_transcription_text=config.collator.keep_transcription_text)
         return speech_collator

-    def __init__(self, aug_file, mean_std_filepath,
-                 vocab_filepath, spm_model_prefix,
-                 random_seed=0,
-                 unit_type="char",
-                 specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
-                 feat_dim=0,  # 'mfcc', 'fbank'
-                 delta_delta=False,  # 'mfcc', 'fbank'
-                 stride_ms=10.0,  # ms
-                 window_ms=20.0,  # ms
-                 n_fft=None,  # fft points
-                 max_freq=None,  # None for samplerate/2
-                 target_sample_rate=16000,  # target sample rate
-                 use_dB_normalization=True,
-                 target_dB=-20,
-                 dither=1.0,
-                 keep_transcription_text=True):
+    def __init__(
+            self,
+            aug_file,
+            mean_std_filepath,
+            vocab_filepath,
+            spm_model_prefix,
+            random_seed=0,
+            unit_type="char",
+            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
+            feat_dim=0,  # 'mfcc', 'fbank'
+            delta_delta=False,  # 'mfcc', 'fbank'
+            stride_ms=10.0,  # ms
+            window_ms=20.0,  # ms
+            n_fft=None,  # fft points
+            max_freq=None,  # None for samplerate/2
+            target_sample_rate=16000,  # target sample rate
+            use_dB_normalization=True,
+            target_dB=-20,
+            dither=1.0,
+            keep_transcription_text=True):
         """SpeechCollator Collator

         Args:
@@ -159,9 +164,8 @@ class SpeechCollator():
         self._local_data = TarLocalData(tar2info={}, tar2object={})

         self._augmentation_pipeline = AugmentationPipeline(
-            augmentation_config=aug_file.read(),
-            random_seed=random_seed)
+            augmentation_config=aug_file.read(), random_seed=random_seed)

         self._normalizer = FeatureNormalizer(
             mean_std_filepath) if mean_std_filepath else None
@ -290,8 +294,6 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens return utts, padded_audios, audio_lens, padded_texts, text_lens
@property @property
def manifest(self): def manifest(self):
return self._manifest return self._manifest
@@ -318,4 +320,4 @@ class SpeechCollator():
     @property
     def stride_ms(self):
         return self._speech_featurizer.stride_ms

deepspeech/io/dataset.py
@@ -12,19 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import io
-import tarfile
-import time
-from collections import namedtuple
 from typing import Optional

-import numpy as np
 from paddle.io import Dataset
 from yacs.config import CfgNode

-from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
-from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
-from deepspeech.frontend.normalizer import FeatureNormalizer
-from deepspeech.frontend.speech import SpeechSegment
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.utils.log import Log
@ -46,8 +38,7 @@ class ManifestDataset(Dataset):
max_output_len=float('inf'), max_output_len=float('inf'),
min_output_len=0.0, min_output_len=0.0,
max_output_input_ratio=float('inf'), max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, min_output_input_ratio=0.0, ))
))
if config is not None: if config is not None:
config.merge_from_other_cfg(default) config.merge_from_other_cfg(default)
@ -66,7 +57,6 @@ class ManifestDataset(Dataset):
assert 'manifest' in config.data assert 'manifest' in config.data
assert config.data.manifest assert config.data.manifest
dataset = cls( dataset = cls(
manifest_path=config.data.manifest, manifest_path=config.data.manifest,
max_input_len=config.data.max_input_len, max_input_len=config.data.max_input_len,
@ -74,8 +64,7 @@ class ManifestDataset(Dataset):
max_output_len=config.data.max_output_len, max_output_len=config.data.max_output_len,
min_output_len=config.data.min_output_len, min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio, max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio, min_output_input_ratio=config.data.min_output_input_ratio, )
)
return dataset return dataset
def __init__(self, def __init__(self,
@@ -111,7 +100,6 @@ class ManifestDataset(Dataset):
             min_output_input_ratio=min_output_input_ratio)
         self._manifest.sort(key=lambda x: x["feat_shape"][0])

-
     def __len__(self):
         return len(self._manifest)
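The `sort` kept at the end of `__init__` orders the manifest by `feat_shape[0]` (the frame count), so sequentially drawn batches contain similar-length utterances and waste less compute on padding. A tiny sketch with made-up entries:

```python
# hypothetical manifest entries; feat_shape = [num_frames, feat_dim]
manifest = [
    {'utt': 'a', 'feat_shape': [820, 161]},
    {'utt': 'b', 'feat_shape': [312, 161]},
    {'utt': 'c', 'feat_shape': [515, 161]},
]
# shortest utterances first, as in ManifestDataset.__init__
manifest.sort(key=lambda x: x["feat_shape"][0])
print([m['utt'] for m in manifest])  # ['b', 'c', 'a']
```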

@ -905,7 +905,6 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict): def __init__(self, configs: dict):
super().__init__(configs) super().__init__(configs)
def forward(self, def forward(self,
feats, feats,
feats_lengths, feats_lengths,
