From c8368410e291f5ef0992309ed0fd19fc9f59865b Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Fri, 4 Jun 2021 12:41:48 +0000 Subject: [PATCH 1/5] utt datapipeline --- deepspeech/exps/deepspeech2/model.py | 4 ++-- deepspeech/io/collator.py | 4 ++-- deepspeech/io/dataset.py | 12 ++++++++---- deepspeech/models/deepspeech2.py | 2 +- examples/chinese_g2p/local/ignore_sandhi.py | 7 +++++-- examples/dataset/librispeech/.gitignore | 14 +++++++------- examples/librispeech/s0/README.md | 2 +- examples/tiny/s0/run.sh | 2 +- 8 files changed, 27 insertions(+), 20 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 8e8a18245..05b55f750 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -75,7 +75,7 @@ class DeepSpeech2Trainer(Trainer): for i, batch in enumerate(self.valid_loader): loss = self.model(*batch) if paddle.isfinite(loss): - num_utts = batch[0].shape[0] + num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) @@ -191,7 +191,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len): + def compute_metrics(self, utt, audio, audio_len, texts, texts_len): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 7f019039c..5b521fbd5 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -51,7 +51,7 @@ class SpeechCollator(): audio_lens = [] texts = [] text_lens = [] - for audio, text in batch: + for utt, audio, text in batch: # audio audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) @@ -75,4 +75,4 @@ class SpeechCollator(): padded_texts = pad_sequence( texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) - return padded_audios, audio_lens, padded_texts, text_lens + return utt, padded_audios, audio_lens, padded_texts, text_lens diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index fba5f7c66..eaa57a4ec 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -284,7 +284,7 @@ class ManifestDataset(Dataset): return self._local_data.tar2object[tarpath].extractfile( self._local_data.tar2info[tarpath][filename]) - def process_utterance(self, audio_file, transcript): + def process_utterance(self, utt, audio_file, transcript): """Load, augment, featurize and normalize for speech data. :param audio_file: Filepath or file object of audio file. 
@@ -323,7 +323,7 @@ class ManifestDataset(Dataset): specgram = self._augmentation_pipeline.transform_feature(specgram) feature_aug_time = time.time() - start_time #logger.debug(f"audio feature augmentation time: {feature_aug_time}") - return specgram, transcript_part + return utt, specgram, transcript_part def _instance_reader_creator(self, manifest): """ @@ -336,7 +336,9 @@ class ManifestDataset(Dataset): def reader(): for instance in manifest: - inst = self.process_utterance(instance["feat"], + # inst = self.process_utterance(instance["feat"], + # instance["text"]) + inst = self.process_utterance(instance["utt"], instance["feat"], instance["text"]) yield inst @@ -347,4 +349,6 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["feat"], instance["text"]) + return self.process_utterance(instance["utt"], instance["feat"], + instance["text"]) + # return self.process_utterance(instance["feat"], instance["text"]) diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 0ff5514de..ab617a534 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer): reduction=True, # sum batch_average=True) # sum / batch_size - def forward(self, audio, audio_len, text, text_len): + def forward(self, utt, audio, audio_len, text, text_len): """Compute Model loss Args: diff --git a/examples/chinese_g2p/local/ignore_sandhi.py b/examples/chinese_g2p/local/ignore_sandhi.py index cda1bd145..b7f37a272 100644 --- a/examples/chinese_g2p/local/ignore_sandhi.py +++ b/examples/chinese_g2p/local/ignore_sandhi.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import argparse -from typing import List, Union from pathlib import Path +from typing import List +from typing import Union def erized(syllable: str) -> bool: @@ -67,7 +68,9 @@ def ignore_sandhi(reference: List[str], generated: List[str]) -> List[str]: return result -def convert_transcriptions(reference: Union[str, Path], generated: Union[str, Path], output: Union[str, Path]): +def convert_transcriptions(reference: Union[str, Path], + generated: Union[str, Path], + output: Union[str, Path]): with open(reference, 'rt') as f_ref: with open(generated, 'rt') as f_gen: with open(output, 'wt') as f_out: diff --git a/examples/dataset/librispeech/.gitignore b/examples/dataset/librispeech/.gitignore index a8d8eb76d..dfd5c67b5 100644 --- a/examples/dataset/librispeech/.gitignore +++ b/examples/dataset/librispeech/.gitignore @@ -1,7 +1,7 @@ -dev-clean/ -dev-other/ -test-clean/ -test-other/ -train-clean-100/ -train-clean-360/ -train-other-500/ +dev-clean +dev-other +test-clean +test-other +train-clean-100 +train-clean-360 +train-other-500 diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 09f700da8..393dd4579 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -3,7 +3,7 @@ ## Deepspeech2 | Model | release | Config | Test set | Loss | WER | -| --- | --- | --- | --- | --- | --- | +| --- | --- | --- | --- | --- | --- | | DeepSpeech2 | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | | DeepSpeech2 | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | | DeepSpeech2 | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index d4961adb2..0f2e3fd18 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -11,7 +11,7 @@ avg_num=1 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; avg_ckpt=avg_${avg_num} -ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}') +ckpt=$(basename ${conf_path} | awk -F'.' 
'{print $1}') ###ckpt = deepspeech2 echo "checkpoint name ${ckpt}" if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then From f3c9f32c9a7782190540cbd9917921a6ff251643 Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Mon, 7 Jun 2021 12:45:39 +0000 Subject: [PATCH 2/5] add utt to train and test 0607 --- deepspeech/exps/deepspeech2/model.py | 11 +++++++---- deepspeech/models/deepspeech2.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 05b55f750..ce8b56ac6 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -43,7 +43,8 @@ class DeepSpeech2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): start = time.time() - loss = self.model(*batch_data) + utt, audio, audio_len, text, text_len = batch_data + loss = self.model(audio, audio_len, text, text_len) loss.backward() layer_tools.print_grads(self.model, print_func=None) self.optimizer.step() @@ -73,7 +74,8 @@ class DeepSpeech2Trainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - loss = self.model(*batch) + utt, audio, audio_len, text, text_len = batch + loss = self.model(audio, audio_len, text, text_len) if paddle.isfinite(loss): num_utts = batch[1].shape[0] num_seen_utts += num_utts @@ -191,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utt, audio, audio_len, texts, texts_len): + def compute_metrics(self, audio, audio_len, texts, texts_len): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -240,7 +242,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): errors_sum, len_refs, num_ins = 0.0, 0, 0 for i, batch in enumerate(self.test_loader): - metrics = self.compute_metrics(*batch) + utt, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(audio, audio_len, texts, texts_len) errors_sum += metrics['errors_sum'] len_refs += metrics['len_refs'] num_ins += metrics['num_ins'] diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index ab617a534..0ff5514de 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -161,7 +161,7 @@ class DeepSpeech2Model(nn.Layer): reduction=True, # sum batch_average=True) # sum / batch_size - def forward(self, utt, audio, audio_len, text, text_len): + def forward(self, audio, audio_len, text, text_len): """Compute Model loss Args: From a58b1cb30ad1a7281c75e7e04abca26b68c0c61c Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 8 Jun 2021 07:06:18 +0000 Subject: [PATCH 3/5] add result output --- deepspeech/exps/deepspeech2/model.py | 26 ++++++++++++++------------ deepspeech/exps/u2/model.py | 17 +++++++++++------ deepspeech/io/collator.py | 5 ++++- deepspeech/models/u2.py | 1 + deepspeech/modules/conv.py | 3 ++- examples/tiny/s0/run.sh | 2 +- examples/tiny/s1/conf/transformer.yaml | 8 ++++---- examples/tiny/s1/run.sh | 10 ++++++---- 8 files changed, 43 insertions(+), 29 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index ce8b56ac6..468bc6521 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -193,7 +193,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): trans.append(''.join([chr(i) for i in 
ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len): + def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout = None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -215,11 +215,13 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cutoff_top_n=cfg.cutoff_top_n, num_processes=cfg.num_proc_bsearch) - for target, result in zip(target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 + if fout: + fout.write(utt + " " + result + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) logger.info("Current error rate [%s] = %f" % @@ -240,16 +242,16 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): cfg = self.config error_rate_type = None errors_sum, len_refs, num_ins = 0.0, 0, 0 - - for i, batch in enumerate(self.test_loader): - utt, audio, audio_len, texts, texts_len = batch - metrics = self.compute_metrics(audio, audio_len, texts, texts_len) - errors_sum += metrics['errors_sum'] - len_refs += metrics['len_refs'] - num_ins += metrics['num_ins'] - error_rate_type = metrics['error_rate_type'] - logger.info("Error rate [%s] (%d/?) = %f" % - (error_rate_type, num_ins, errors_sum / len_refs)) + with open(self.args.result_file, 'w') as fout: + for i, batch in enumerate(self.test_loader): + utts, audio, audio_len, texts, texts_len = batch + metrics = self.compute_metrics(utts, audio, audio_len, texts, texts_len, fout) + errors_sum += metrics['errors_sum'] + len_refs += metrics['len_refs'] + num_ins += metrics['num_ins'] + error_rate_type = metrics['error_rate_type'] + logger.info("Error rate [%s] (%d/?) 
= %f" % + (error_rate_type, num_ins, errors_sum / len_refs)) # logging msg = "Test: " diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index f166a071e..8837444d6 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -76,8 +76,9 @@ class U2Trainer(Trainer): def train_batch(self, batch_index, batch_data, msg): train_conf = self.config.training start = time.time() + utt, audio, audio_len, text, text_len = batch_data - loss, attention_loss, ctc_loss = self.model(*batch_data) + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) # loss div by `batch_size * accum_grad` loss /= train_conf.accum_grad loss.backward() @@ -119,9 +120,10 @@ class U2Trainer(Trainer): num_seen_utts = 1 total_loss = 0.0 for i, batch in enumerate(self.valid_loader): - loss, attention_loss, ctc_loss = self.model(*batch) + utt, audio, audio_len, text, text_len = batch + loss, attention_loss, ctc_loss = self.model(audio, audio_len, text, text_len) if paddle.isfinite(loss): - num_utts = batch[0].shape[0] + num_utts = batch[1].shape[0] num_seen_utts += num_utts total_loss += float(loss) * num_utts valid_losses['val_loss'].append(float(loss)) @@ -366,7 +368,7 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, audio, audio_len, texts, texts_len, fout=None): + def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None, fref=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -393,13 +395,15 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for target, result in zip(target_transcripts, result_transcripts): + for utt, target, result in zip(utts, target_transcripts, result_transcripts): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 if fout: - fout.write(result + "\n") + fout.write(utt + " " + result + "\n") + if fref: + fref.write(utt + " " + target + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) logger.info("One example error rate [%s] = %f" % @@ -428,6 +432,7 @@ class U2Tester(U2Trainer): num_time = 0.0 with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): + # utt, audio, audio_len, text, text_len = batch metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] num_time += metrics["decode_time"] diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 5b521fbd5..3bec9875f 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -51,7 +51,10 @@ class SpeechCollator(): audio_lens = [] texts = [] text_lens = [] + utts = [] for utt, audio, text in batch: + #utt + utts.append(utt) # audio audios.append(audio.T) # [T, D] audio_lens.append(audio.shape[1]) @@ -75,4 +78,4 @@ class SpeechCollator(): padded_texts = pad_sequence( texts, padding_value=IGNORE_ID).astype(np.int64) text_lens = np.array(text_lens).astype(np.int64) - return utt, padded_audios, audio_lens, padded_texts, text_lens + return utts, padded_audios, audio_lens, padded_texts, text_lens diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 238e2d35c..bcfddaef0 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -905,6 +905,7 @@ class U2InferModel(U2Model): def __init__(self, configs: dict): 
super().__init__(configs) + def forward(self, feats, feats_lengths, diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index 111f5d3b4..c17d59a1b 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -114,7 +114,8 @@ class ConvBn(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + # masks = masks.type_as(x) + masks = masks.astype(x) x = x.multiply(masks) return x, x_len diff --git a/examples/tiny/s0/run.sh b/examples/tiny/s0/run.sh index 0f2e3fd18..d7e153e8d 100755 --- a/examples/tiny/s0/run.sh +++ b/examples/tiny/s0/run.sh @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 35c11731c..dd3e02677 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -8,7 +8,7 @@ data: spm_model_prefix: 'data/bpe_unigram_200' mean_std_filepath: "" augmentation_config: conf/augmentation.json - batch_size: 4 + batch_size: 2 #4 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens @@ -31,7 +31,7 @@ data: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 2 + num_workers: 0 #2 # network architecture @@ -70,7 +70,7 @@ model: training: - n_epoch: 20 + n_epoch: 2 accum_grad: 1 global_grad_clip: 5.0 optim: adam @@ -85,7 +85,7 @@ training: decoding: - batch_size: 64 + batch_size: 8 #64 error_rate_type: wer decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index f7e41a338..fdcf7ff01 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -20,20 +20,22 @@ fi if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # train model, all `ckpt` under `exp` dir - CUDA_VISIBLE_DEVICES=4,5,6,7 ./local/train.sh ${conf_path} ${ckpt} + ./local/train.sh ${conf_path} ${ckpt} fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + # CUDA_VISIBLE_DEVICES=7 + ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + # CUDA_VISIBLE_DEVICES= + ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi From 8781ab58cf3e6316f2be4c018556652882634c5a Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Tue, 8 Jun 2021 07:15:59 +0000 Subject: [PATCH 4/5] fix export and run.sh --- deepspeech/modules/conv.py | 2 +- deepspeech/modules/rnn.py | 2 +- examples/aishell/s0/run.sh | 2 +- examples/aishell/s1/run.sh | 2 +- examples/librispeech/s0/run.sh | 2 +- examples/librispeech/s1/run.sh | 2 +- 6 files 
changed, 6 insertions(+), 6 deletions(-) diff --git a/deepspeech/modules/conv.py b/deepspeech/modules/conv.py index c17d59a1b..8bf48b2c8 100644 --- a/deepspeech/modules/conv.py +++ b/deepspeech/modules/conv.py @@ -115,7 +115,7 @@ class ConvBn(nn.Layer): masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] # TODO(Hui Zhang): not support bool multiply # masks = masks.type_as(x) - masks = masks.astype(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/deepspeech/modules/rnn.py b/deepspeech/modules/rnn.py index 29bd28839..01b55c4a2 100644 --- a/deepspeech/modules/rnn.py +++ b/deepspeech/modules/rnn.py @@ -309,6 +309,6 @@ class RNNStack(nn.Layer): masks = make_non_pad_mask(x_len) #[B, T] masks = masks.unsqueeze(-1) # [B, T, 1] # TODO(Hui Zhang): not support bool multiply - masks = masks.type_as(x) + masks = masks.astype(x.dtype) x = x.multiply(masks) return x, x_len diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index d4961adb2..4073c81b9 100755 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -26,7 +26,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 016502298..4cf09553b 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -25,7 +25,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s0/run.sh b/examples/librispeech/s0/run.sh index 3e536bd79..6553e073d 100755 --- a/examples/librispeech/s0/run.sh +++ b/examples/librispeech/s0/run.sh @@ -24,7 +24,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index 472e6ebfb..65194d902 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -24,7 +24,7 @@ fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # avg n best model - ./local/avg.sh exp/${ckpt}/checkpoints ${avg_num} + avg.sh exp/${ckpt}/checkpoints ${avg_num} fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then From b4bda290aaef47bef700f1d61c12f626e869e8ce Mon Sep 17 00:00:00 2001 From: Haoxin Ma <745165806@qq.com> Date: Wed, 9 Jun 2021 13:29:35 +0000 Subject: [PATCH 5/5] fix bugs --- deepspeech/exps/u2/model.py | 5 +---- deepspeech/io/dataset.py | 12 +++++------- examples/tiny/s1/conf/transformer.yaml | 4 ++-- examples/tiny/s1/run.sh | 6 ++---- 4 files changed, 10 insertions(+), 17 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 8837444d6..334d6bc8e 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -368,7 +368,7 @@ class U2Tester(U2Trainer): trans.append(''.join([chr(i) for i in ids])) return trans - def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None, fref=None): + def compute_metrics(self, utts, audio, audio_len, texts, texts_len, fout=None): cfg = self.config.decoding errors_sum, len_refs, num_ins = 0.0, 0, 0 errors_func = 
error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors @@ -402,8 +402,6 @@ class U2Tester(U2Trainer): num_ins += 1 if fout: fout.write(utt + " " + result + "\n") - if fref: - fref.write(utt + " " + target + "\n") logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" % (target, result)) logger.info("One example error rate [%s] = %f" % @@ -432,7 +430,6 @@ class U2Tester(U2Trainer): num_time = 0.0 with open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): - # utt, audio, audio_len, text, text_len = batch metrics = self.compute_metrics(*batch, fout=fout) num_frames += metrics['num_frames'] num_time += metrics["decode_time"] diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py index eaa57a4ec..1cf3827d3 100644 --- a/deepspeech/io/dataset.py +++ b/deepspeech/io/dataset.py @@ -284,7 +284,7 @@ class ManifestDataset(Dataset): return self._local_data.tar2object[tarpath].extractfile( self._local_data.tar2info[tarpath][filename]) - def process_utterance(self, utt, audio_file, transcript): + def process_utterance(self, audio_file, transcript): """Load, augment, featurize and normalize for speech data. :param audio_file: Filepath or file object of audio file. @@ -323,7 +323,7 @@ class ManifestDataset(Dataset): specgram = self._augmentation_pipeline.transform_feature(specgram) feature_aug_time = time.time() - start_time #logger.debug(f"audio feature augmentation time: {feature_aug_time}") - return utt, specgram, transcript_part + return specgram, transcript_part def _instance_reader_creator(self, manifest): """ @@ -336,9 +336,7 @@ class ManifestDataset(Dataset): def reader(): for instance in manifest: - # inst = self.process_utterance(instance["feat"], - # instance["text"]) - inst = self.process_utterance(instance["utt"], instance["feat"], + inst = self.process_utterance(instance["feat"], instance["text"]) yield inst @@ -349,6 +347,6 @@ class ManifestDataset(Dataset): def __getitem__(self, idx): instance = self._manifest[idx] - return self.process_utterance(instance["utt"], instance["feat"], + feat, text =self.process_utterance(instance["feat"], instance["text"]) - # return self.process_utterance(instance["feat"], instance["text"]) + return instance["utt"], feat, text diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index dd3e02677..0a7cf3be8 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -8,7 +8,7 @@ data: spm_model_prefix: 'data/bpe_unigram_200' mean_std_filepath: "" augmentation_config: conf/augmentation.json - batch_size: 2 #4 + batch_size: 4 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens @@ -31,7 +31,7 @@ data: keep_transcription_text: False sortagrad: True shuffle_method: batch_shuffle - num_workers: 0 #2 + num_workers: 2 # network architecture diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index fdcf7ff01..b148869b7 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -30,12 +30,10 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - # CUDA_VISIBLE_DEVICES=7 - ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=7 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # export ckpt avg_n - # CUDA_VISIBLE_DEVICES= - ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} 
exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi
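
Taken together, the series changes the batch contract of the data pipeline: ManifestDataset.__getitem__ now returns (utt, feat, text), SpeechCollator pads each batch into (utts, padded_audios, audio_lens, padded_texts, text_lens), and DeepSpeech2Model.forward keeps its original four arguments, so every consumer unpacks the utterance ids before calling the model. Below is a minimal sketch of that collate contract in plain numpy; the padding logic and the IGNORE_ID value are assumed stand-ins for the real helpers in deepspeech/io/collator.py, not the actual implementation.

    import numpy as np

    IGNORE_ID = -1  # assumed stand-in for the collator's text padding value

    def collate(batch):
        """Pad a list of (utt, audio, text) samples into a five-element batch.

        audio is a [D, T] feature array, text a 1-D array of token ids.
        """
        utts, audios, audio_lens, texts, text_lens = [], [], [], [], []
        for utt, audio, text in batch:
            utts.append(utt)            # keep the utterance id with its sample
            audios.append(audio.T)      # [T, D]
            audio_lens.append(audio.shape[1])
            texts.append(text)
            text_lens.append(text.shape[0])
        max_t = max(a.shape[0] for a in audios)
        padded_audios = np.stack(
            [np.pad(a, ((0, max_t - a.shape[0]), (0, 0))) for a in audios]
        ).astype(np.float32)
        max_l = max(t.shape[0] for t in texts)
        padded_texts = np.stack(
            [np.pad(t, (0, max_l - t.shape[0]), constant_values=IGNORE_ID)
             for t in texts]
        ).astype(np.int64)
        # utterance ids ride along as a plain list in slot 0, which is why
        # num_utts is computed from batch[1] (the audio tensor) in the patches
        return (utts, padded_audios, np.array(audio_lens, np.int64),
                padded_texts, np.array(text_lens, np.int64))

Because the ids occupy slot 0, callers that previously did loss = self.model(*batch) must unpack the tuple explicitly, which is exactly the change made to train_batch and valid in PATCH 2/5.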
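
The decoding side then consumes those ids: compute_metrics gains a leading utts argument plus an optional fout handle, and each hypothesis is written to args.result_file as one "<utt> <hyp>" line. A sketch of the resulting test loop, assuming a test_loader that yields the five-element batches and a compute_metrics callable with the PATCH 3/5 signature; it mirrors DeepSpeech2Tester.test rather than reproducing it.

    def run_test(test_loader, compute_metrics, result_file):
        # Accumulate corpus-level error counts while dumping one
        # "<utt> <hyp>" line per utterance to the result file.
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        error_rate_type = None
        with open(result_file, 'w') as fout:
            for batch in test_loader:
                utts, audio, audio_len, texts, texts_len = batch
                metrics = compute_metrics(utts, audio, audio_len, texts,
                                          texts_len, fout)
                errors_sum += metrics['errors_sum']
                len_refs += metrics['len_refs']
                num_ins += metrics['num_ins']
                error_rate_type = metrics['error_rate_type']
        return error_rate_type, errors_sum / len_refs

The extra fref handle for dumping reference transcripts, tried for the U2 tester in PATCH 3/5, is dropped again in PATCH 5/5, so the result file stays a plain two-column list of utterance id and hypothesis.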