revise example/ting/s1

pull/665/head
Haoxin Ma 4 years ago
parent b9110af9d3
commit 7bae32f384

@ -72,7 +72,7 @@ _C.collator =CN(
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
dither=1.0, # feature dither dither=1.0, # feature dither
keep_transcription_text=True keep_transcription_text=False
)) ))
DeepSpeech2Model.params(_C.model) DeepSpeech2Model.params(_C.model)

@ -336,13 +336,14 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
# config.data.max_output_input_ratio = float('inf') # config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config) test_dataset = ManifestDataset.from_config(config)
config.collator.keep_transcription_text = True
# return text ord id # return text ord id
self.test_loader = DataLoader( self.test_loader = DataLoader(
test_dataset, test_dataset,
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(config=config, keep_transcription_text=True)) collate_fn=SpeechCollator.from_config(config))
logger.info("Setup test Dataloader!") logger.info("Setup test Dataloader!")
def setup_output_dir(self): def setup_output_dir(self):

@ -22,6 +22,13 @@ _C = CfgNode()
_C.data = ManifestDataset.params() _C.data = ManifestDataset.params()
_C.collator =CfgNode(
dict(
augmentation_config="",
unit_type="char",
keep_transcription_text=False
))
_C.model = U2Model.params() _C.model = U2Model.params()
_C.training = U2Trainer.params() _C.training = U2Trainer.params()

@ -221,7 +221,7 @@ class U2Trainer(Trainer):
config.data.augmentation_config = "" config.data.augmentation_config = ""
dev_dataset = ManifestDataset.from_config(config) dev_dataset = ManifestDataset.from_config(config)
collate_fn = SpeechCollator(keep_transcription_text=False) collate_fn = SpeechCollator.from_config(config)
if self.parallel: if self.parallel:
batch_sampler = SortagradDistributedBatchSampler( batch_sampler = SortagradDistributedBatchSampler(
train_dataset, train_dataset,
@ -266,12 +266,13 @@ class U2Trainer(Trainer):
# config.data.max_output_input_ratio = float('inf') # config.data.max_output_input_ratio = float('inf')
test_dataset = ManifestDataset.from_config(config) test_dataset = ManifestDataset.from_config(config)
# return text ord id # return text ord id
config.collator.keep_transcription_text = True
self.test_loader = DataLoader( self.test_loader = DataLoader(
test_dataset, test_dataset,
batch_size=config.decoding.batch_size, batch_size=config.decoding.batch_size,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True)) collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test Dataloader!") logger.info("Setup train/valid/test Dataloader!")
def setup_model(self): def setup_model(self):
@ -375,7 +376,7 @@ class U2Tester(U2Trainer):
error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time() start_time = time.time()
text_feature = self.test_loader.dataset.text_feature text_feature = self.test_loader.collate_fn.text_feature
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
audio, audio,
@ -423,7 +424,7 @@ class U2Tester(U2Trainer):
self.model.eval() self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}") logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
stride_ms = self.test_loader.dataset.stride_ms stride_ms = self.config.collator.stride_ms
error_rate_type = None error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0 num_frames = 0.0

@ -82,7 +82,7 @@ def read_manifest(
] ]
if all(conditions): if all(conditions):
manifest.append(json_data) manifest.append(json_data)
return manifest, json_data["feat_shape"][-1] return manifest
def rms_to_db(rms: float): def rms_to_db(rms: float):

@ -56,7 +56,7 @@ class SpeechCollator():
use_dB_normalization=True, use_dB_normalization=True,
target_dB=-20, target_dB=-20,
dither=1.0, # feature dither dither=1.0, # feature dither
keep_transcription_text=True keep_transcription_text=False
)) ))
if config is not None: if config is not None:
@ -75,7 +75,7 @@ class SpeechCollator():
""" """
assert 'augmentation_config' in config.collator assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator assert 'mean_std_filepath' in config.data
assert 'vocab_filepath' in config.data assert 'vocab_filepath' in config.data
assert 'specgram_type' in config.collator assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator assert 'n_fft' in config.collator
@ -94,7 +94,7 @@ class SpeechCollator():
speech_collator = cls( speech_collator = cls(
aug_file=aug_file, aug_file=aug_file,
random_seed=0, random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath, mean_std_filepath=config.data.mean_std_filepath,
unit_type=config.collator.unit_type, unit_type=config.collator.unit_type,
vocab_filepath=config.data.vocab_filepath, vocab_filepath=config.data.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix, spm_model_prefix=config.collator.spm_model_prefix,
@ -282,26 +282,11 @@ class SpeechCollator():
text_lens = np.array(text_lens).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property @property
def text_feature(self): def text_feature(self):
return self._text_featurizer return self._speech_featurizer.text_feature
self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property @property
def stride_ms(self): def stride_ms(self):

@ -161,7 +161,7 @@ class ManifestDataset(Dataset):
# self._rng = np.random.RandomState(random_seed) # self._rng = np.random.RandomState(random_seed)
# read manifest # read manifest
self._manifest, self._feature_size = read_manifest( self._manifest = read_manifest(
manifest_path=manifest_path, manifest_path=manifest_path,
max_input_len=max_input_len, max_input_len=max_input_len,
min_input_len=min_input_len, min_input_len=min_input_len,
@ -213,16 +213,8 @@ class ManifestDataset(Dataset):
Returns: Returns:
int: audio feature size. int: audio feature size.
""" """
return self._feature_size return self._manifest[0]["feat_shape"][-1]
@property
def stride_ms(self):
"""time length in `ms` unit per frame
Returns:
float: time(ms)/frame
"""
return self._audio_featurizer.stride_ms
def __len__(self): def __len__(self):

@ -6,7 +6,6 @@ data:
mean_std_filepath: data/mean_std.json mean_std_filepath: data/mean_std.json
unit_type: char unit_type: char
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
batch_size: 4 batch_size: 4
min_input_len: 0.0 min_input_len: 0.0
max_input_len: 27.0 max_input_len: 27.0
@ -14,18 +13,6 @@ data:
max_output_len: 400.0 max_output_len: 400.0
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
max_output_input_ratio: 10.0 max_output_input_ratio: 10.0
specgram_type: linear
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0
window_ms: 20.0
delta_delta: False
dither: 1.0
use_dB_normalization: True
target_dB: -20
random_seed: 0
keep_transcription_text: False
sortagrad: True sortagrad: True
shuffle_method: batch_shuffle shuffle_method: batch_shuffle
num_workers: 0 num_workers: 0
@ -33,7 +20,6 @@ data:
collator: collator:
augmentation_config: conf/augmentation.json augmentation_config: conf/augmentation.json
random_seed: 0 random_seed: 0
mean_std_filepath: data/mean_std.json
spm_model_prefix: spm_model_prefix:
specgram_type: linear specgram_type: linear
feat_dim: feat_dim:
@ -46,7 +32,7 @@ collator:
use_dB_normalization: True use_dB_normalization: True
target_dB: -20 target_dB: -20
dither: 1.0 dither: 1.0
keep_transcription_text: True keep_transcription_text: False
model: model:
num_conv_layers: 2 num_conv_layers: 2

@ -7,7 +7,6 @@ data:
unit_type: 'spm' unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200' spm_model_prefix: 'data/bpe_unigram_200'
mean_std_filepath: "" mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4 batch_size: 4
min_input_len: 0.5 # second min_input_len: 0.5 # second
max_input_len: 20.0 # second max_input_len: 20.0 # second
@ -16,23 +15,26 @@ data:
min_output_input_ratio: 0.05 min_output_input_ratio: 0.05
max_output_input_ratio: 10.0 max_output_input_ratio: 10.0
raw_wav: True # use raw_wav or kaldi feature raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank sortagrad: True
shuffle_method: batch_shuffle
num_workers: 0 #2
collator:
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
specgram_type: fbank
feat_dim: 80 feat_dim: 80
delta_delta: False delta_delta: False
dither: 1.0
target_sample_rate: 16000
max_freq: None
n_fft: None
stride_ms: 10.0 stride_ms: 10.0
window_ms: 25.0 window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True use_dB_normalization: True
target_dB: -20 target_dB: -20
random_seed: 0 dither: 1.0
keep_transcription_text: False keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2
# network architecture # network architecture
model: model:
@ -70,7 +72,7 @@ model:
training: training:
n_epoch: 2 n_epoch: 3
accum_grad: 1 accum_grad: 1
global_grad_clip: 5.0 global_grad_clip: 5.0
optim: adam optim: adam

Loading…
Cancel
Save