diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 0976ec1a..7806aaa4 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -444,7 +444,7 @@ class U2Tester(U2Trainer): start_time = time.time() text_feature = self.test_loader.collate_fn.text_feature target_transcripts = self.ordid2token(texts, texts_len) - result_transcripts = self.model.decode( + result_transcripts, result_tokenids = self.model.decode( audio, audio_len, text_feature=text_feature, @@ -462,14 +462,19 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, - result_transcripts): + for utt, target, result, rec_tids in zip( + utts, target_transcripts, result_transcripts, result_tokenids): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 if fout: - fout.write({"utt": utt, "ref": target, "hyp": result}) + fout.write({ + "utt": utt, + "refs": [target], + "hyps": [result], + "hyps_tokenid": [rec_tids], + }) logger.info(f"Utt: {utt}") logger.info(f"Ref: {target}") logger.info(f"Hyp: {result}") diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index c182c598..18e29b28 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -390,6 +390,10 @@ class U2Tester(U2Trainer): def __init__(self, config, args): super().__init__(config, args) + self.text_feature = TextFeaturizer( + unit_type=self.config.collator.unit_type, + vocab_filepath=self.config.collator.vocab_filepath, + spm_model_prefix=self.config.collator.spm_model_prefix) def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ @@ -413,15 +417,11 @@ class U2Tester(U2Trainer): error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer start_time = time.time() - text_feature = TextFeaturizer( - unit_type=self.config.collator.unit_type, - vocab_filepath=self.config.collator.vocab_filepath, - spm_model_prefix=self.config.collator.spm_model_prefix) - target_transcripts = self.id2token(texts, texts_len, text_feature) - result_transcripts = self.model.decode( + target_transcripts = self.id2token(texts, texts_len, self.text_feature) + result_transcripts, result_tokenids = self.model.decode( audio, audio_len, - text_feature=text_feature, + text_feature=self.text_feature, decoding_method=cfg.decoding_method, lang_model_path=cfg.lang_model_path, beam_alpha=cfg.alpha, @@ -436,14 +436,19 @@ class U2Tester(U2Trainer): simulate_streaming=cfg.simulate_streaming) decode_time = time.time() - start_time - for utt, target, result in zip(utts, target_transcripts, - result_transcripts): + for i, (utt, target, result, rec_tids) in enumerate(zip( + utts, target_transcripts, result_transcripts, result_tokenids)): errors, len_ref = errors_func(target, result) errors_sum += errors len_refs += len_ref num_ins += 1 if fout: - fout.write({"utt": utt, "ref": target, "hyp": result}) + fout.write({ + "utt": utt, + "refs": [target], + "hyps": [result], + "hyps_tokenid": [rec_tids], + }) logger.info(f"Utt: {utt}") logger.info(f"Ref: {target}") logger.info(f"Hyp: {result}") diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 553ffcb5..ae1feb78 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -32,7 +32,7 @@ __all__ = ["SpeechCollator", "TripletSpeechCollator"] logger = Log(__name__).getlog() -def tokenids(text, keep_transcription_text): +def _tokenids(text, keep_transcription_text): # for training text is token ids tokens = text # token ids @@ -93,6 +93,8 @@ class SpeechCollatorBase(): a user-defined shape) within one batch. """ self.keep_transcription_text = keep_transcription_text + self.train_mode = not keep_transcription_text + self.stride_ms = stride_ms self.window_ms = window_ms self.feat_dim = feat_dim @@ -192,6 +194,7 @@ class SpeechCollatorBase(): texts = [] text_lens = [] utts = [] + tids = [] # tokenids for idx, item in enumerate(batch): utts.append(item['utt']) @@ -203,7 +206,7 @@ class SpeechCollatorBase(): audios.append(audio) # [T, D] audio_lens.append(audio.shape[0]) - tokens = tokenids(text, self.keep_transcription_text) + tokens = _tokenids(text, self.keep_transcription_text) texts.append(tokens) text_lens.append(tokens.shape[0]) diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py index 310f5f58..d8ef9ba6 100644 --- a/deepspeech/io/dataloader.py +++ b/deepspeech/io/dataloader.py @@ -142,6 +142,15 @@ class BatchDataLoader(): collate_fn=batch_collate, num_workers=self.n_iter_processes, ) + def __len__(self): + return len(self.dataloader) + + def __iter__(self): + return self.dataloader.__iter__() + + def __call__(self): + return self.__iter__() + def __repr__(self): echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> " echo += f"train_mode: {self.train_mode}, " @@ -159,12 +168,3 @@ class BatchDataLoader(): echo += f"num_workers: {self.n_iter_processes}, " echo += f"file: {self.json_file}" return echo - - def __len__(self): - return len(self.dataloader) - - def __iter__(self): - return self.dataloader.__iter__() - - def __call__(self): - return self.__iter__() diff --git a/deepspeech/models/u2/u2.py b/deepspeech/models/u2/u2.py index e6cd7b5c..fd63fa9c 100644 --- a/deepspeech/models/u2/u2.py +++ b/deepspeech/models/u2/u2.py @@ -809,7 +809,8 @@ class U2BaseModel(nn.Layer): raise ValueError(f"Not support decoding method: {decoding_method}") res = [text_feature.defeaturize(hyp) for hyp in hyps] - return res + res_tokenids = [hyp for hyp in hyps] + return res, res_tokenids class U2Model(U2BaseModel): diff --git a/examples/librispeech/s1/conf/augmentation.json b/examples/librispeech/s1/conf/augmentation.json index 40a5b790..31c481c8 100644 --- a/examples/librispeech/s1/conf/augmentation.json +++ b/examples/librispeech/s1/conf/augmentation.json @@ -1,12 +1,4 @@ [ - { - "type": "shift", - "params": { - "min_shift_ms": -5, - "max_shift_ms": 5 - }, - "prob": 1.0 - }, { "type": "speed", "params": { @@ -16,6 +8,14 @@ }, "prob": 0.0 }, + { + "type": "shift", + "params": { + "min_shift_ms": -5, + "max_shift_ms": 5 + }, + "prob": 1.0 + }, { "type": "specaug", "params": { diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md index e4022f01..34c65c11 100644 --- a/examples/librispeech/s2/README.md +++ b/examples/librispeech/s2/README.md @@ -1,41 +1,9 @@ # LibriSpeech -## Data -| Data Subset | Duration in Seconds | -| data/manifest.train | 0.83s ~ 29.735s | -| data/manifest.dev | 1.065 ~ 35.155s | -| data/manifest.test-clean | 1.285s ~ 34.955s | - -## Conformer -| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | - | - | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | | | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | | | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | | | - -### Test w/o length filter -| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | -| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean-all | attention | | | - - -## Chunk Conformer - -| Model | Params | Config | Augmentation| Test set | Decode method | Chunk Size & Left Chunks | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | -| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention | 16, -1 | | | -| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 16, -1 | | | -| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 16, -1 | | - | -| conformer | 47.63 M | conf/chunk_conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 16, -1 | | - | - - ## Transformer -| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | | | - -### Test w/o length filter -| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| Model | Params | Config | Augmentation| Test Set | Decode Method | Loss | WER % | | --- | --- | --- | --- | --- | --- | --- | --- | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean-all | attention | | | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention | 6.395054340362549 | 4.2 | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_greedy_search | 6.395054340362549 | 5.0 | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | ctc_prefix_beam_search | 6.395054340362549 | | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug | test-clean | attention_rescore | 6.395054340362549 | | diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index b86224ff..c9eed4f9 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -5,9 +5,9 @@ data: test_manifest: data/manifest.test-clean collator: - vocab_filepath: data/train_960_unigram5000_units.txt - unit_type: 'spm' - spm_model_prefix: 'data/train_960_unigram5000' + vocab_filepath: data/bpe_unigram_5000_units.txt + unit_type: spm + spm_model_prefix: data/bpe_unigram_5000 feat_dim: 83 stride_ms: 10.0 window_ms: 25.0 diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 5eeb2d61..67174152 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -46,15 +46,17 @@ pids=() # initialize pids for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do ( + echo "${dmethd} decoding" for rtask in ${recog_set}; do ( - decode_dir=decode_${rtask}_${dmethd}_$(basename ${config_path%.*})_${lmtag} + echo "${rtask} dataset" + decode_dir=decode_${rtask/-/_}_${dmethd}_$(basename ${config_path%.*})_${lmtag} feat_recog_dir=${datadir} mkdir -p ${expdir}/${decode_dir} mkdir -p ${feat_recog_dir} # split data - split_json.sh ${feat_recog_dir}/manifest.${rtask} ${nj} + split_json.sh manifest.${rtask} ${nj} #### use CPU for decoding ngpu=0 @@ -74,17 +76,16 @@ for dmethd in attention ctc_greedy_search ctc_prefix_beam_search attention_resco --opts decoding.batch_size ${batch_size} \ --opts data.test_manifest ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} - score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel}.model --wer true ${expdir}/${decode_dir} ${dict} + score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${expdir}/${decode_dir} ${dict} ) & pids+=($!) # store background pids + i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done + [ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false done -) & -pids+=($!) # store background pids +) done -i=0; for pid in "${pids[@]}"; do wait ${pid} || ((++i)); done -[ ${i} -gt 0 ] && echo "$0: ${i} background jobs are failed." && false echo "Finished" exit 0 diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 46b6ac1b..8a219381 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -32,7 +32,7 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then diff --git a/tools/Makefile b/tools/Makefile index b8b00293..5690ea91 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -6,7 +6,7 @@ CC ?= gcc # used for sph2pipe # CXX = clang++ # Uncomment these lines... # CC = clang # ...to build with Clang. -WGET ?= wget +WGET ?= wget --no-check-certificate .PHONY: all clean diff --git a/utils/json2trn.py b/utils/json2trn.py new file mode 100755 index 00000000..873fde4f --- /dev/null +++ b/utils/json2trn.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# 2018 Xuankai Chang (Shanghai Jiao Tong University) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import argparse +import json +import logging +import sys + +import jsonlines +from utility import get_commandline_args + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert a json to a transcription file with a token dictionary", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument("json", type=str, help="jsonlines files") + parser.add_argument("dict", type=str, help="dict, not used.") + parser.add_argument( + "--num-spkrs", type=int, default=1, help="number of speakers") + parser.add_argument( + "--refs", type=str, nargs="+", help="ref for all speakers") + parser.add_argument( + "--hyps", type=str, nargs="+", help="hyp for all outputs") + return parser + + +def main(args): + args = get_parser().parse_args(args) + convert(args.json, args.dict, args.refs, args.hyps, args.num_spkrs) + + +def convert(jsonf, dic, refs, hyps, num_spkrs=1): + n_ref = len(refs) + n_hyp = len(hyps) + assert n_ref == n_hyp + assert n_ref == num_spkrs + + # logging info + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info(get_commandline_args()) + + logging.info("reading %s", jsonf) + with jsonlines.open(jsonf, "r") as f: + j = [item for item in f] + + logging.info("reading %s", dic) + with open(dic, "r") as f: + dictionary = f.readlines() + char_list = [entry.split(" ")[0] for entry in dictionary] + char_list.insert(0, "") + char_list.append("") + + for ns in range(num_spkrs): + hyp_file = open(hyps[ns], "w") + ref_file = open(refs[ns], "w") + + for x in j: + # recognition hypothesis + if num_spkrs == 1: + #seq = [char_list[int(i)] for i in x['hyps_tokenid'][0]] + seq = x['hyps'][0] + else: + seq = [char_list[int(i)] for i in x['hyps_tokenid'][ns]] + # In the recognition hypothesis, + # the symbol is usually attached in the last part of the sentence + # and it is removed below. + #hyp_file.write(" ".join(seq).replace("", "")) + hyp_file.write(seq.replace("", "")) + # spk-uttid + hyp_file.write(" (" + x["utt"] + ")\n") + + # reference + if num_spkrs == 1: + seq = x["refs"][0] + else: + seq = x['refs'][ns] + # Unlike the recognition hypothesis, + # the reference is directly generated from a token without dictionary + # to avoid to include symbols in the reference to make scoring normal. + # The detailed discussion can be found at + # https://github.com/espnet/espnet/issues/993 + # ref_file.write( + # seq + " (" + j["utts"][x]["utt2spk"].replace("-", "_") + "-" + x + ")\n" + # ) + ref_file.write(seq + " (" + x['utt'] + ")\n") + + hyp_file.close() + ref_file.close() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/utils/score_sclite.sh b/utils/score_sclite.sh index 7ded76eb..99214b7d 100755 --- a/utils/score_sclite.sh +++ b/utils/score_sclite.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash +set -e + # Copyright 2017 Johns Hopkins University (Shinji Watanabe) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) diff --git a/utils/utility.py b/utils/utility.py index a6b81d73..b4db518a 100755 --- a/utils/utility.py +++ b/utils/utility.py @@ -14,6 +14,7 @@ import hashlib import json import os +import sys import tarfile import zipfile from typing import Text @@ -21,7 +22,7 @@ from typing import Text __all__ = [ "check_md5sum", "getfile_insensitive", "download_multi", "download", "unpack", "unzip", "md5file", "print_arguments", "add_arguments", - "read_manifest" + "read_manifest", "get_commandline_args" ] @@ -46,6 +47,40 @@ def read_manifest(manifest_path): return manifest +def get_commandline_args(): + extra_chars = [ + " ", + ";", + "&", + "(", + ")", + "|", + "^", + "<", + ">", + "?", + "*", + "[", + "]", + "$", + "`", + '"', + "\\", + "!", + "{", + "}", + ] + + # Escape the extra characters for shell + argv = [ + arg.replace("'", "'\\''") if all(char not in arg + for char in extra_chars) else + "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv + ] + + return sys.executable + " " + " ".join(argv) + + def print_arguments(args, info=None): """Print argparse's arguments.