utt datapipeline

pull/657/head
Haoxin Ma 4 years ago
parent 03e5a64d26
commit c8368410e2

@ -75,7 +75,7 @@ class DeepSpeech2Trainer(Trainer):
for i, batch in enumerate(self.valid_loader):
loss = self.model(*batch)
if paddle.isfinite(loss):
num_utts = batch[0].shape[0]
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
@ -191,7 +191,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
trans.append(''.join([chr(i) for i in ids]))
return trans
def compute_metrics(self, audio, audio_len, texts, texts_len):
def compute_metrics(self, utt, audio, audio_len, texts, texts_len):
cfg = self.config.decoding
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors

@ -51,7 +51,7 @@ class SpeechCollator():
audio_lens = []
texts = []
text_lens = []
for audio, text in batch:
for utt, audio, text in batch:
# audio
audios.append(audio.T) # [T, D]
audio_lens.append(audio.shape[1])
@ -75,4 +75,4 @@ class SpeechCollator():
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return padded_audios, audio_lens, padded_texts, text_lens
return utt, padded_audios, audio_lens, padded_texts, text_lens

@ -284,7 +284,7 @@ class ManifestDataset(Dataset):
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
def process_utterance(self, audio_file, transcript):
def process_utterance(self, utt, audio_file, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
@ -323,7 +323,7 @@ class ManifestDataset(Dataset):
specgram = self._augmentation_pipeline.transform_feature(specgram)
feature_aug_time = time.time() - start_time
#logger.debug(f"audio feature augmentation time: {feature_aug_time}")
return specgram, transcript_part
return utt, specgram, transcript_part
def _instance_reader_creator(self, manifest):
"""
@ -336,7 +336,9 @@ class ManifestDataset(Dataset):
def reader():
for instance in manifest:
inst = self.process_utterance(instance["feat"],
# inst = self.process_utterance(instance["feat"],
# instance["text"])
inst = self.process_utterance(instance["utt"], instance["feat"],
instance["text"])
yield inst
@ -347,4 +349,6 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
return self.process_utterance(instance["feat"], instance["text"])
return self.process_utterance(instance["utt"], instance["feat"],
instance["text"])
# return self.process_utterance(instance["feat"], instance["text"])

@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer):
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
def forward(self, utt, audio, audio_len, text, text_len):
"""Compute Model loss
Args:

@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List, Union
from pathlib import Path
from typing import List
from typing import Union
def erized(syllable: str) -> bool:
@ -67,7 +68,9 @@ def ignore_sandhi(reference: List[str], generated: List[str]) -> List[str]:
return result
def convert_transcriptions(reference: Union[str, Path], generated: Union[str, Path], output: Union[str, Path]):
def convert_transcriptions(reference: Union[str, Path],
generated: Union[str, Path],
output: Union[str, Path]):
with open(reference, 'rt') as f_ref:
with open(generated, 'rt') as f_gen:
with open(output, 'wt') as f_out:

@ -1,7 +1,7 @@
dev-clean/
dev-other/
test-clean/
test-other/
train-clean-100/
train-clean-360/
train-other-500/
dev-clean
dev-other
test-clean
test-other
train-clean-100
train-clean-360
train-other-500

@ -11,7 +11,7 @@ avg_num=1
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') ###ckpt = deepspeech2
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then

Loading…
Cancel
Save