From 9626e99ce45ab39c7de7b46ddd43112d45a1211d Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 6 Apr 2021 09:36:30 +0000
Subject: [PATCH] add utils

---
Reviewer notes (free-form text; nothing before the first "diff --git" line
is applied to the tree):
- One-record-per-line layout restored; the number of lines in every hunk
  matches its "@@" header.
- Blank context lines in the two tests/ hunks were placed so that the hunk
  counts add up -- confirm against the target tree before applying.
- Angle-bracket text that had been stripped is restored: the usage
  placeholders in combine_data.sh and sym2int.pl, the <F> read loop in
  sym2int.pl, and the <unk> / <<unk>> markers in spm_decode.
- spm_decode: with --input_format id the tokens are now converted through
  tok2int() before DecodeIds(); tok2int() used to be defined but never called.
- Comment typos fixed. Blob ids on the "index" lines were not recomputed.

 tests/deepspeech2_model_test.py |   7 +-
 tests/u2_model_test.py          |  45 ++++------
 utils/combine_data.sh           | 146 ++++++++++++++++++++++++++++++++
 utils/parse_options.sh          |  97 +++++++++++++++++++++
 utils/spm_decode                |  49 +++++++++++
 utils/spm_encode                |  99 ++++++++++++++++++++++
 utils/spm_train                 |  13 +++
 utils/sym2int.pl                | 104 +++++++++++++++++++++++
 8 files changed, 529 insertions(+), 31 deletions(-)
 create mode 100644 utils/combine_data.sh
 create mode 100644 utils/parse_options.sh
 create mode 100755 utils/spm_decode
 create mode 100755 utils/spm_encode
 create mode 100755 utils/spm_train
 create mode 100644 utils/sym2int.pl

diff --git a/tests/deepspeech2_model_test.py b/tests/deepspeech2_model_test.py
index 8ada42c35..5400e6dbc 100644
--- a/tests/deepspeech2_model_test.py
+++ b/tests/deepspeech2_model_test.py
@@ -28,12 +28,11 @@ class TestDeepSpeech2Model(unittest.TestCase):
 
         #(B, T, D)
         audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
-        audio_len = np.random.randint(
-            max_len, size=self.batch_size, dtype='int32')
+        audio_len = np.random.randint(max_len, size=self.batch_size)
         audio_len[-1] = max_len
         #(B, U)
-        text = np.array([[1, 2], [1, 2]], dtype='int32')
-        text_len = np.array([2] * self.batch_size, dtype='int32')
+        text = np.array([[1, 2], [1, 2]])
+        text_len = np.array([2] * self.batch_size)
 
         self.audio = paddle.to_tensor(audio, dtype='float32')
         self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
diff --git a/tests/u2_model_test.py b/tests/u2_model_test.py
index e2230a394..a86210750 100644
--- a/tests/u2_model_test.py
+++ b/tests/u2_model_test.py
@@ -21,39 +21,30 @@ from deepspeech.models.u2 import U2ConformerModel
 
 class TestU2Model(unittest.TestCase):
     def setUp(self):
-        batch_size = 2
-        feat_dim = 161
-        max_len = 100
-        audio = np.random.randn(batch_size, feat_dim, max_len)
-        audio_len = np.random.randint(100, size=batch_size, dtype='int32')
-        audio_len[-1] = 100
-        text = np.array([[1, 2], [1, 2]], dtype='int32')
-        text_len = np.array([2] * batch_size, dtype='int32')
+        paddle.set_device('cpu')
+
+        self.batch_size = 2
+        self.feat_dim = 161
+        self.max_len = 64
+
+        #(B, T, D)
+        audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim)
+        audio_len = np.random.randint(self.max_len, size=self.batch_size)
+        audio_len[-1] = self.max_len
+        #(B, U)
+        text = np.array([[1, 2], [1, 2]])
+        text_len = np.array([2] * self.batch_size)
 
         self.audio = paddle.to_tensor(audio, dtype='float32')
         self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
         self.text = paddle.to_tensor(text, dtype='int32')
         self.text_len = paddle.to_tensor(text_len, dtype='int64')
 
-        print(audio.shape)
-        print(audio_len.shape)
-        print(text.shape)
-        print(text_len.shape)
-        print("-----------------")
-
-    def test_ds2_1(self):
-        model = DeepSpeech2Model(
-            feat_size=feat_dim,
-            dict_size=10,
-            num_conv_layers=2,
-            num_rnn_layers=3,
-            rnn_size=1024,
-            use_gru=False,
-            share_rnn_weights=False, )
-        logits, probs, logits_len = model(self.audio, self.audio_len, self.text,
-                                          self.text_len)
-        print('probs.shape', probs.shape)
-        print("-----------------")
+    def test_transformer(self):
+        model = U2TransformerModel()
+
+    def test_conformer(self):
+        model = U2ConformerModel()
 
 
 if __name__ == '__main__':
diff --git a/utils/combine_data.sh b/utils/combine_data.sh
new file mode 100644
index 000000000..7a217bad8
--- /dev/null
+++ b/utils/combine_data.sh
@@ -0,0 +1,146 @@
+#!/bin/bash
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
+#           2014 David Snyder
+
+# This script combines the data from multiple source directories into
+# a single destination directory.
+
+# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
+# about what these directories contain.
+
+# Begin configuration section.
+extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
+skip_fix=false # skip the fix_data_dir.sh in the end
+# End configuration section.
+
+echo "$0 $@" # Print the command line for logging
+
+if [ -f path.sh ]; then . ./path.sh; fi
+if [ -f parse_options.sh ]; then . parse_options.sh || exit 1; fi
+
+if [ $# -lt 2 ]; then
+  echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
+  echo "Note, files that don't appear in all source dirs will not be combined,"
+  echo "with the exception of utt2uniq and segments, which are created where necessary."
+  exit 1
+fi
+
+dest=$1;
+shift;
+
+first_src=$1;
+
+rm -r $dest 2>/dev/null
+mkdir -p $dest;
+
+export LC_ALL=C
+
+for dir in $*; do
+  if [ ! -f $dir/utt2spk ]; then
+    echo "$0: no such file $dir/utt2spk"
+    exit 1;
+  fi
+done
+
+# Check that frame_shift are compatible, where present together with features.
+dir_with_frame_shift=
+for dir in $*; do
+  if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
+    if [[ $dir_with_frame_shift ]] &&
+       ! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
+      echo "$0:error: different frame_shift in directories $dir and " \
+           "$dir_with_frame_shift. Cannot combine features."
+      exit 1;
+    fi
+    dir_with_frame_shift=$dir
+  fi
+done
+
+# W.r.t. utt2uniq file the script has different behavior compared to other files
+# it is not compulsory for it to exist in src directories, but if it exists in
+# even one it should exist in all. We will create the files where necessary
+has_utt2uniq=false
+for in_dir in $*; do
+  if [ -f $in_dir/utt2uniq ]; then
+    has_utt2uniq=true
+    break
+  fi
+done
+
+if $has_utt2uniq; then
+  # we are going to create an utt2uniq file in the destdir
+  for in_dir in $*; do
+    if [ ! -f $in_dir/utt2uniq ]; then
+      # we assume that utt2uniq is a one to one mapping
+      cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
+    else
+      cat $in_dir/utt2uniq
+    fi
+  done | sort -k1 > $dest/utt2uniq
+  echo "$0: combined utt2uniq"
+else
+  echo "$0 [info]: not combining utt2uniq as it does not exist"
+fi
+# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
+extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
+
+# segments are treated similarly to utt2uniq. If it exists in some, but not all
+# src directories, then we generate segments where necessary.
+has_segments=false
+for in_dir in $*; do
+  if [ -f $in_dir/segments ]; then
+    has_segments=true
+    break
+  fi
+done
+
+if $has_segments; then
+  for in_dir in $*; do
+    if [ ! -f $in_dir/segments ]; then
+      echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
+      utils/data/get_segments_for_data.sh $in_dir
+    else
+      cat $in_dir/segments
+    fi
+  done | sort -k1 > $dest/segments
+  echo "$0: combined segments"
+else
+  echo "$0 [info]: not combining segments as it does not exist"
+fi
+
+for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
+  exists_somewhere=false
+  absent_somewhere=false
+  for d in $*; do
+    if [ -f $d/$file ]; then
+      exists_somewhere=true
+    else
+      absent_somewhere=true
+    fi
+  done
+
+  if ! $absent_somewhere; then
+    set -o pipefail
+    ( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
+    set +o pipefail
+    echo "$0: combined $file"
+  else
+    if ! $exists_somewhere; then
+      echo "$0 [info]: not combining $file as it does not exist"
+    else
+      echo "$0 [info]: **not combining $file as it does not exist everywhere**"
+    fi
+  fi
+done
+
+tools/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+
+if [[ $dir_with_frame_shift ]]; then
+  cp $dir_with_frame_shift/frame_shift $dest
+fi
+
+if ! $skip_fix ; then
+  tools/fix_data_dir.sh $dest || exit 1;
+fi
+
+exit 0
\ No newline at end of file
diff --git a/utils/parse_options.sh b/utils/parse_options.sh
new file mode 100644
index 000000000..f7151668f
--- /dev/null
+++ b/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+
+# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
+#                Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+  if [ "${!argpos}" == "--config" ]; then
+    argpos_plus1=$((argpos+1))
+    config=${!argpos_plus1}
+    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+    . $config # source the config file.
+  fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+  [ -z "${1:-}" ] && break; # break if there are no arguments
+  case "$1" in
+    # If the enclosing script is called with --help option, print the help
+    # message and exit. Scripts should put help messages in $help_message
+    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+      else printf "$help_message\n" 1>&2 ; fi;
+      exit 0 ;;
+    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+      exit 1 ;;
+    # If the first command-line argument begins with "--" (e.g. --foo-bar),
+    # then work out the variable name as $name, which will equal "foo_bar".
+    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+      # Next we test whether the variable in question is undefined-- if so it's
+      # an invalid option and we die. Note: $0 evaluates to the name of the
+      # enclosing script.
+      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+      # is undefined. We then have to wrap this test inside "eval" because
+      # foo_bar is itself inside a variable ($name).
+      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+      oldval="`eval echo \\$$name`";
+      # Work out whether we seem to be expecting a Boolean argument.
+      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval $name=\"$2\";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+  *) break;
+  esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
\ No newline at end of file
diff --git a/utils/spm_decode b/utils/spm_decode
new file mode 100755
index 000000000..a94aa2a8d
--- /dev/null
+++ b/utils/spm_decode
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True,
+                        help="sentencepiece model to use for decoding")
+    parser.add_argument("--input", default=None, help="input file to decode")
+    parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
+    args = parser.parse_args()
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.input_format == "piece":
+        def decode(l):
+            return "".join(sp.DecodePieces(l))
+    elif args.input_format == "id":
+        def decode(l):
+            return "".join(sp.DecodeIds(list(map(tok2int, l))))
+    else:
+        raise NotImplementedError
+
+    def tok2int(tok):
+        # remap reference-side <unk> (represented as <<unk>>) to 0
+        return int(tok) if tok != "<<unk>>" else 0
+
+    if args.input is None:
+        h = sys.stdin
+    else:
+        h = open(args.input, "r", encoding="utf-8")
+    for line in h:
+        print(decode(line.split()))
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/utils/spm_encode b/utils/spm_encode
new file mode 100755
index 000000000..081e40e96
--- /dev/null
+++ b/utils/spm_encode
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True,
+                        help="sentencepiece model to use for encoding")
+    parser.add_argument("--inputs", nargs="+", default=['-'],
+                        help="input files to filter/encode")
+    parser.add_argument("--outputs", nargs="+", default=['-'],
+                        help="path to save encoded outputs")
+    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+    parser.add_argument("--min-len", type=int, metavar="N",
+                        help="filter sentence pairs with fewer than N tokens")
+    parser.add_argument("--max-len", type=int, metavar="N",
+                        help="filter sentence pairs with more than N tokens")
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(args.outputs), \
+        "number of input and output paths should match"
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.output_format == "piece":
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+    elif args.output_format == "id":
+        def encode(l):
+            return list(map(str, sp.EncodeAsIds(l)))
+    else:
+        raise NotImplementedError
+
+    if args.min_len is not None or args.max_len is not None:
+        def valid(line):
+            return (
+                (args.min_len is None or len(line) >= args.min_len) and
+                (args.max_len is None or len(line) <= args.max_len)
+            )
+    else:
+        def valid(lines):
+            return True
+
+    with contextlib.ExitStack() as stack:
+        inputs = [
+            stack.enter_context(open(input, "r", encoding="utf-8"))
+            if input != "-" else sys.stdin
+            for input in args.inputs
+        ]
+        outputs = [
+            stack.enter_context(open(output, "w", encoding="utf-8"))
+            if output != "-" else sys.stdout
+            for output in args.outputs
+        ]
+
+        stats = {
+            "num_empty": 0,
+            "num_filtered": 0,
+        }
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for i, lines in enumerate(zip(*inputs), start=1):
+            enc_lines = list(map(encode_line, lines))
+            if not any(enc_line is None for enc_line in enc_lines):
+                for enc_line, output_h in zip(enc_lines, outputs):
+                    print(" ".join(enc_line), file=output_h)
+            if i % 10000 == 0:
+                print("processed {} lines".format(i), file=sys.stderr)
+
+        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/utils/spm_train b/utils/spm_train
new file mode 100755
index 000000000..44330a55f
--- /dev/null
+++ b/utils/spm_train
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+
+if __name__ == "__main__":
+    spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
\ No newline at end of file
diff --git a/utils/sym2int.pl b/utils/sym2int.pl
new file mode 100644
index 000000000..642f41bf7
--- /dev/null
+++ b/utils/sym2int.pl
@@ -0,0 +1,104 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+$ignore_oov = 0;
+
+for($x = 0; $x < 2; $x++) {
+  if ($ARGV[0] eq "--map-oov") {
+    shift @ARGV;
+    $map_oov = shift @ARGV;
+    if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") {
+      # disallow '-f', the empty string and anything ending in words.txt as the
+      # OOV symbol because these are likely command-line errors.
+      die "the --map-oov option requires an argument";
+    }
+  }
+  if ($ARGV[0] eq "-f") {
+    shift @ARGV;
+    $field_spec = shift @ARGV;
+    if ($field_spec =~ m/^\d+$/) {
+      $field_begin = $field_spec - 1; $field_end = $field_spec - 1;
+    }
+    if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10)
+      if ($1 ne "") {
+        $field_begin = $1 - 1; # Change to zero-based indexing.
+      }
+      if ($2 ne "") {
+        $field_end = $2 - 1; # Change to zero-based indexing.
+      }
+    }
+    if (!defined $field_begin && !defined $field_end) {
+      die "Bad argument to -f option: $field_spec";
+    }
+  }
+}
+
+$symtab = shift @ARGV;
+if (!defined $symtab) {
+  print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" .
+    "options: [--map-oov <oov-symbol>] [-f <field-range>]\n" .
+      "note: <field-range> can look like 4-5, or 4-, or 5-, or 1.\n";
+}
+open(F, "<$symtab") || die "Error opening symbol table file $symtab";
+while(<F>) {
+  @A = split(" ", $_);
+  @A == 2 || die "bad line in symbol table file: $_";
+  $sym2int{$A[0]} = $A[1] + 0;
+}
+
+if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up
+  if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; }
+  $map_oov = $sym2int{$map_oov};
+}
+
+$num_warning = 0;
+$max_warning = 20;
+
+while (<>) {
+  @A = split(" ", $_);
+  @B = ();
+  for ($n = 0; $n < @A; $n++) {
+    $a = $A[$n];
+    if ( (!defined $field_begin || $n >= $field_begin)
+         && (!defined $field_end || $n <= $field_end)) {
+      $i = $sym2int{$a};
+      if (!defined ($i)) {
+        if (defined $map_oov) {
+          if ($num_warning++ < $max_warning) {
+            print STDERR "sym2int.pl: replacing $a with $map_oov\n";
+            if ($num_warning == $max_warning) {
+              print STDERR "sym2int.pl: not warning for OOVs any more times\n";
+            }
+          }
+          $i = $map_oov;
+        } else {
+          $pos = $n+1;
+          die "sym2int.pl: undefined symbol $a (in position $pos)\n";
+        }
+      }
+      $a = $i;
+    }
+    push @B, $a;
+  }
+  print join(" ", @B);
+  print "\n";
+}
+if ($num_warning > 0) {
+  print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n";
+}
+
+exit(0);
\ No newline at end of file