fix pre-commit

pull/665/head
Haoxin Ma 4 years ago
parent 089a8ed602
commit 3a743f3717

@ -13,12 +13,11 @@
# limitations under the License.
from yacs.config import CfgNode
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.collator import SpeechCollator
from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester
from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
_C = CfgNode()

@ -15,11 +15,13 @@
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
@ -33,9 +35,6 @@ from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log
from typing import Optional
from yacs.config import CfgNode
logger = Log(__name__).getlog()
@ -184,7 +183,6 @@ class DeepSpeech2Trainer(Trainer):
collate_fn_train = SpeechCollator.from_config(config)
config.collator.augmentation_config = ""
collate_fn_dev = SpeechCollator.from_config(config)
self.train_loader = DataLoader(
@ -235,7 +233,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -257,7 +261,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
@ -287,7 +292,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
utts, audio, audio_len, texts, texts_len = batch
metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout)
metrics = self.compute_metrics(utts, audio, audio_len, texts,
texts_len, fout)
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']

@ -15,9 +15,9 @@ from yacs.config import CfgNode
from deepspeech.exps.u2.model import U2Tester
from deepspeech.exps.u2.model import U2Trainer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.u2 import U2Model
from deepspeech.io.collator import SpeechCollator
_C = CfgNode()

@ -78,7 +78,8 @@ class U2Trainer(Trainer):
start = time.time()
utt, audio, audio_len, text, text_len = batch_data
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
loss.backward()
@ -121,7 +122,8 @@ class U2Trainer(Trainer):
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, audio, audio_len, text, text_len = batch
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len)
loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
text_len)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
@ -372,7 +374,13 @@ class U2Tester(U2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None):
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
@ -399,7 +407,8 @@ class U2Tester(U2Trainer):
simulate_streaming=cfg.simulate_streaming)
decode_time = time.time() - start_time
for utt, target, result in zip(utts, target_transcripts, result_transcripts):
for utt, target, result in zip(utts, target_transcripts,
result_transcripts):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref

@ -11,21 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import time
from collections import namedtuple
from typing import Optional
import numpy as np
from yacs.config import CfgNode
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
import io
import time
from yacs.config import CfgNode
from typing import Optional
from collections import namedtuple
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log
__all__ = ["SpeechCollator"]
@ -34,6 +34,7 @@ logger = Log(__name__).getlog()
# namedtupe need global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
@ -56,8 +57,7 @@ class SpeechCollator():
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False
))
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
@ -84,7 +84,9 @@ class SpeechCollator():
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config, mode='r', encoding='utf8')
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
@ -109,12 +111,15 @@ class SpeechCollator():
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text
)
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
def __init__(self, aug_file, mean_std_filepath,
vocab_filepath, spm_model_prefix,
def __init__(
self,
aug_file,
mean_std_filepath,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
@ -159,8 +164,7 @@ class SpeechCollator():
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(),
random_seed=random_seed)
augmentation_config=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
@ -290,8 +294,6 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def manifest(self):
return self._manifest

@ -12,19 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile
import time
from collections import namedtuple
from typing import Optional
import numpy as np
from paddle.io import Dataset
from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
@ -46,8 +38,7 @@ class ManifestDataset(Dataset):
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0,
))
min_output_input_ratio=0.0, ))
if config is not None:
config.merge_from_other_cfg(default)
@ -66,7 +57,6 @@ class ManifestDataset(Dataset):
assert 'manifest' in config.data
assert config.data.manifest
dataset = cls(
manifest_path=config.data.manifest,
max_input_len=config.data.max_input_len,
@ -74,8 +64,7 @@ class ManifestDataset(Dataset):
max_output_len=config.data.max_output_len,
min_output_len=config.data.min_output_len,
max_output_input_ratio=config.data.max_output_input_ratio,
min_output_input_ratio=config.data.min_output_input_ratio,
)
min_output_input_ratio=config.data.min_output_input_ratio, )
return dataset
def __init__(self,
@ -111,7 +100,6 @@ class ManifestDataset(Dataset):
min_output_input_ratio=min_output_input_ratio)
self._manifest.sort(key=lambda x: x["feat_shape"][0])
def __len__(self):
return len(self._manifest)

@ -905,7 +905,6 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict):
super().__init__(configs)
def forward(self,
feats,
feats_lengths,

Loading…
Cancel
Save