utt datapipeline

pull/657/head
Haoxin Ma 4 years ago
parent 03e5a64d26
commit c8368410e2

@ -75,7 +75,7 @@ class DeepSpeech2Trainer(Trainer):
for i, batch in enumerate(self.valid_loader): for i, batch in enumerate(self.valid_loader):
loss = self.model(*batch) loss = self.model(*batch)
if paddle.isfinite(loss): if paddle.isfinite(loss):
num_utts = batch[0].shape[0] num_utts = batch[1].shape[0]
num_seen_utts += num_utts num_seen_utts += num_utts
total_loss += float(loss) * num_utts total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss)) valid_losses['val_loss'].append(float(loss))
@ -191,7 +191,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids])) trans.append(''.join([chr(i) for i in ids]))
return trans return trans
def compute_metrics(self, audio, audio_len, texts, texts_len): def compute_metrics(self, utt, audio, audio_len, texts, texts_len):
cfg = self.config.decoding cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors

@ -51,7 +51,7 @@ class SpeechCollator():
audio_lens = [] audio_lens = []
texts = [] texts = []
text_lens = [] text_lens = []
for audio, text in batch: for utt, audio, text in batch:
# audio # audio
audios.append(audio.T) # [T, D] audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1]) audio_lens.append(audio.shape[1])
@ -75,4 +75,4 @@ class SpeechCollator():
padded_texts = pad_sequence( padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64) texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, audio_lens, padded_texts, text_lens return utt, padded_audios, audio_lens, padded_texts, text_lens

@ -284,7 +284,7 @@ class ManifestDataset(Dataset):
return self._local_data.tar2object[tarpath].extractfile( return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename]) self._local_data.tar2info[tarpath][filename])
def process_utterance(self, audio_file, transcript): def process_utterance(self, utt, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data. """Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file. :param audio_file: Filepath or file object of audio file.
@ -323,7 +323,7 @@ class ManifestDataset(Dataset):
specgram = self._augmentation_pipeline.transform_feature(specgram) specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}") #logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part return utt, specgram, transcript_part
def _instance_reader_creator(self, manifest): def _instance_reader_creator(self, manifest):
""" """
@ -336,7 +336,9 @@ class ManifestDataset(Dataset):
def reader(): def reader():
for instance in manifest: for instance in manifest:
inst = self.process_utterance(instance["feat"], # inst = self.process_utterance(instance["feat"],
# instance["text"])
inst = self.process_utterance(instance["utt"], instance["feat"],
instance["text"]) instance["text"])
yield inst yield inst
@ -347,4 +349,6 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
instance = self._manifest[idx] instance = self._manifest[idx]
return self.process_utterance(instance["feat"], instance["text"]) return self.process_utterance(instance["utt"], instance["feat"],
instance["text"])
# return self.process_utterance(instance["feat"], instance["text"])

@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer):
reduction=True, # sum reduction=True, # sum
batch_average=True) # sum / batch_size batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len): def forward(self, utt, audio, audio_len, text, text_len):
"""Compute Model loss """Compute Model loss
Args: Args:

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse import argparse
from typing import List, Union
from pathlib import Path from pathlib import Path
from typing import List
from typing import Union
def erized(syllable: str) -> bool: def erized(syllable: str) -> bool:
@ -67,7 +68,9 @@ def ignore_sandhi(reference: List[str], generated: List[str]) -> List[str]:
return result return result
def convert_transcriptions(reference: Union[str, Path], generated: Union[str, Path], output: Union[str, Path]): def convert_transcriptions(reference: Union[str, Path],
generated: Union[str, Path],
output: Union[str, Path]):
with open(reference, 'rt') as f_ref: with open(reference, 'rt') as f_ref:
with open(generated, 'rt') as f_gen: with open(generated, 'rt') as f_gen:
with open(output, 'wt') as f_out: with open(output, 'wt') as f_out:

@ -1,7 +1,7 @@
dev-clean/ dev-clean
dev-other/ dev-other
test-clean/ test-clean
test-other/ test-other
train-clean-100/ train-clean-100
train-clean-360/ train-clean-360
train-other-500/ train-other-500

@ -11,7 +11,7 @@ avg_num=1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num} avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2
echo "checkpoint name ${ckpt}" echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then

Loading…
Cancel
Save