pull/578/head
Hui Zhang 5 years ago
parent 907c93392f
commit 9626e99ce4

@ -28,12 +28,11 @@ class TestDeepSpeech2Model(unittest.TestCase):
#(B, T, D)
audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
audio_len = np.random.randint(
max_len, size=self.batch_size, dtype='int32')
audio_len = np.random.randint(max_len, size=self.batch_size)
audio_len[-1] = max_len
#(B, U)
text = np.array([[1, 2], [1, 2]], dtype='int32')
text_len = np.array([2] * self.batch_size, dtype='int32')
text = np.array([[1, 2], [1, 2]])
text_len = np.array([2] * self.batch_size)
self.audio = paddle.to_tensor(audio, dtype='float32')
self.audio_len = paddle.to_tensor(audio_len, dtype='int64')

@ -21,39 +21,30 @@ from deepspeech.models.u2 import U2ConformerModel
class TestU2Model(unittest.TestCase):
def setUp(self):
batch_size = 2
feat_dim = 161
max_len = 100
audio = np.random.randn(batch_size, feat_dim, max_len)
audio_len = np.random.randint(100, size=batch_size, dtype='int32')
audio_len[-1] = 100
text = np.array([[1, 2], [1, 2]], dtype='int32')
text_len = np.array([2] * batch_size, dtype='int32')
paddle.set_device('cpu')
self.batch_size = 2
self.feat_dim = 161
self.max_len = 64
#(B, T, D)
audio = np.random.randn(self.batch_size, self.max_len, self.feat_dim)
audio_len = np.random.randint(self.max_len, size=self.batch_size)
audio_len[-1] = self.max_len
#(B, U)
text = np.array([[1, 2], [1, 2]])
text_len = np.array([2] * self.batch_size)
self.audio = paddle.to_tensor(audio, dtype='float32')
self.audio_len = paddle.to_tensor(audio_len, dtype='int64')
self.text = paddle.to_tensor(text, dtype='int32')
self.text_len = paddle.to_tensor(text_len, dtype='int64')
print(audio.shape)
print(audio_len.shape)
print(text.shape)
print(text_len.shape)
print("-----------------")
def test_ds2_1(self):
model = DeepSpeech2Model(
feat_size=feat_dim,
dict_size=10,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=1024,
use_gru=False,
share_rnn_weights=False, )
logits, probs, logits_len = model(self.audio, self.audio_len, self.text,
self.text_len)
print('probs.shape', probs.shape)
print("-----------------")
def test_transformer(self):
model = U2TransformerModel()
def test_conformer(self):
model = U2ConformerModel()
if __name__ == '__main__':

@ -0,0 +1,146 @@
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey). Apache 2.0.
# 2014 David Snyder
# This script combines the data from multiple source directories into
# a single destination directory.
# See http://kaldi-asr.org/doc/data_prep.html#data_prep_data for information
# about what these directories contain.
# Begin configuration section.
extra_files= # specify additional files in 'src-data-dir' to merge, ex. "file1 file2 ..."
skip_fix=false # skip the fix_data_dir.sh in the end
# End configuration section.
echo "$0 $@" # Print the command line for logging
if [ -f path.sh ]; then . ./path.sh; fi
if [ -f parse_options.sh ]; then . parse_options.sh || exit 1; fi
if [ $# -lt 2 ]; then
echo "Usage: combine_data.sh [--extra-files 'file1 file2'] <dest-data-dir> <src-data-dir1> <src-data-dir2> ..."
echo "Note, files that don't appear in all source dirs will not be combined,"
echo "with the exception of utt2uniq and segments, which are created where necessary."
exit 1
fi
dest=$1;
shift;
first_src=$1;
rm -r $dest 2>/dev/null
mkdir -p $dest;
export LC_ALL=C
for dir in $*; do
if [ ! -f $dir/utt2spk ]; then
echo "$0: no such file $dir/utt2spk"
exit 1;
fi
done
# Check that frame_shift are compatible, where present together with features.
dir_with_frame_shift=
for dir in $*; do
if [[ -f $dir/feats.scp && -f $dir/frame_shift ]]; then
if [[ $dir_with_frame_shift ]] &&
! cmp -s $dir_with_frame_shift/frame_shift $dir/frame_shift; then
echo "$0:error: different frame_shift in directories $dir and " \
"$dir_with_frame_shift. Cannot combine features."
exit 1;
fi
dir_with_frame_shift=$dir
fi
done
# W.r.t. utt2uniq file the script has different behavior compared to other files
# it is not compulsary for it to exist in src directories, but if it exists in
# even one it should exist in all. We will create the files where necessary
has_utt2uniq=false
for in_dir in $*; do
if [ -f $in_dir/utt2uniq ]; then
has_utt2uniq=true
break
fi
done
if $has_utt2uniq; then
# we are going to create an utt2uniq file in the destdir
for in_dir in $*; do
if [ ! -f $in_dir/utt2uniq ]; then
# we assume that utt2uniq is a one to one mapping
cat $in_dir/utt2spk | awk '{printf("%s %s\n", $1, $1);}'
else
cat $in_dir/utt2uniq
fi
done | sort -k1 > $dest/utt2uniq
echo "$0: combined utt2uniq"
else
echo "$0 [info]: not combining utt2uniq as it does not exist"
fi
# some of the old scripts might provide utt2uniq as an extrafile, so just remove it
extra_files=$(echo "$extra_files"|sed -e "s/utt2uniq//g")
# segments are treated similarly to utt2uniq. If it exists in some, but not all
# src directories, then we generate segments where necessary.
has_segments=false
for in_dir in $*; do
if [ -f $in_dir/segments ]; then
has_segments=true
break
fi
done
if $has_segments; then
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
utils/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
done | sort -k1 > $dest/segments
echo "$0: combined segments"
else
echo "$0 [info]: not combining segments as it does not exist"
fi
for file in utt2spk utt2lang utt2dur utt2num_frames reco2dur feats.scp text cmvn.scp vad.scp reco2file_and_channel wav.scp spk2gender $extra_files; do
exists_somewhere=false
absent_somewhere=false
for d in $*; do
if [ -f $d/$file ]; then
exists_somewhere=true
else
absent_somewhere=true
fi
done
if ! $absent_somewhere; then
set -o pipefail
( for f in $*; do cat $f/$file; done ) | sort -k1 > $dest/$file || exit 1;
set +o pipefail
echo "$0: combined $file"
else
if ! $exists_somewhere; then
echo "$0 [info]: not combining $file as it does not exist"
else
echo "$0 [info]: **not combining $file as it does not exist everywhere**"
fi
fi
done
tools/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
tools/fix_data_dir.sh $dest || exit 1;
fi
exit 0

@ -0,0 +1,97 @@
#!/bin/bash
# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
# Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order
for ((argpos=1; argpos<$#; argpos++)); do
if [ "${!argpos}" == "--config" ]; then
argpos_plus1=$((argpos+1))
config=${!argpos_plus1}
[ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
. $config # source the config file.
fi
done
###
### No we process the command line options
###
while true; do
[ -z "${1:-}" ] && break; # break if there are no arguments
case "$1" in
# If the enclosing script is called with --help option, print the help
# message and exit. Scripts should put help messages in $help_message
--help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
else printf "$help_message\n" 1>&2 ; fi;
exit 0 ;;
--*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
exit 1 ;;
# If the first command-line argument begins with "--" (e.g. --foo-bar),
# then work out the variable name as $name, which will equal "foo_bar".
--*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
# Next we test whether the variable in question is undefned-- if so it's
# an invalid option and we die. Note: $0 evaluates to the name of the
# enclosing script.
# The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
# is undefined. We then have to wrap this test inside "eval" because
# foo_bar is itself inside a variable ($name).
eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
oldval="`eval echo \\$$name`";
# Work out whether we seem to be expecting a Boolean argument.
if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
was_bool=true;
else
was_bool=false;
fi
# Set the variable to the right value-- the escaped quotes make it work if
# the option had spaces, like --cmd "queue.pl -sync y"
eval $name=\"$2\";
# Check that Boolean-valued arguments are really Boolean.
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
exit 1;
fi
shift 2;
;;
*) break;
esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.

@ -0,0 +1,49 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# https://github.com/pytorch/fairseq/blob/master/LICENSE
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import sys
import sentencepiece as spm
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
help="sentencepiece model to use for decoding")
parser.add_argument("--input", default=None, help="input file to decode")
parser.add_argument("--input_format", choices=["piece", "id"], default="piece")
args = parser.parse_args()
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
if args.input_format == "piece":
def decode(l):
return "".join(sp.DecodePieces(l))
elif args.input_format == "id":
def decode(l):
return "".join(sp.DecodeIds(l))
else:
raise NotImplementedError
def tok2int(tok):
# remap reference-side <unk> (represented as <<unk>>) to 0
return int(tok) if tok != "<<unk>>" else 0
if args.input is None:
h = sys.stdin
else:
h = open(args.input, "r", encoding="utf-8")
for line in h:
print(decode(line.split()))
if __name__ == "__main__":
main()

@ -0,0 +1,99 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in
# https://github.com/pytorch/fairseq/blob/master/LICENSE
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import contextlib
import sys
import sentencepiece as spm
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True,
help="sentencepiece model to use for encoding")
parser.add_argument("--inputs", nargs="+", default=['-'],
help="input files to filter/encode")
parser.add_argument("--outputs", nargs="+", default=['-'],
help="path to save encoded outputs")
parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
parser.add_argument("--min-len", type=int, metavar="N",
help="filter sentence pairs with fewer than N tokens")
parser.add_argument("--max-len", type=int, metavar="N",
help="filter sentence pairs with more than N tokens")
args = parser.parse_args()
assert len(args.inputs) == len(args.outputs), \
"number of input and output paths should match"
sp = spm.SentencePieceProcessor()
sp.Load(args.model)
if args.output_format == "piece":
def encode(l):
return sp.EncodeAsPieces(l)
elif args.output_format == "id":
def encode(l):
return list(map(str, sp.EncodeAsIds(l)))
else:
raise NotImplementedError
if args.min_len is not None or args.max_len is not None:
def valid(line):
return (
(args.min_len is None or len(line) >= args.min_len) and
(args.max_len is None or len(line) <= args.max_len)
)
else:
def valid(lines):
return True
with contextlib.ExitStack() as stack:
inputs = [
stack.enter_context(open(input, "r", encoding="utf-8"))
if input != "-" else sys.stdin
for input in args.inputs
]
outputs = [
stack.enter_context(open(output, "w", encoding="utf-8"))
if output != "-" else sys.stdout
for output in args.outputs
]
stats = {
"num_empty": 0,
"num_filtered": 0,
}
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
return None
for i, lines in enumerate(zip(*inputs), start=1):
enc_lines = list(map(encode_line, lines))
if not any(enc_line is None for enc_line in enc_lines):
for enc_line, output_h in zip(enc_lines, outputs):
print(" ".join(enc_line), file=output_h)
if i % 10000 == 0:
print("processed {} lines".format(i), file=sys.stderr)
print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
if __name__ == "__main__":
main()

@ -0,0 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# https://github.com/pytorch/fairseq/blob/master/LICENSE
import sys
import sentencepiece as spm
if __name__ == "__main__":
spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))

@ -0,0 +1,104 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation Johns Hopkins University (Author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
$ignore_oov = 0;
for($x = 0; $x < 2; $x++) {
if ($ARGV[0] eq "--map-oov") {
shift @ARGV;
$map_oov = shift @ARGV;
if ($map_oov eq "-f" || $map_oov =~ m/words\.txt$/ || $map_oov eq "") {
# disallow '-f', the empty string and anything ending in words.txt as the
# OOV symbol because these are likely command-line errors.
die "the --map-oov option requires an argument";
}
}
if ($ARGV[0] eq "-f") {
shift @ARGV;
$field_spec = shift @ARGV;
if ($field_spec =~ m/^\d+$/) {
$field_begin = $field_spec - 1; $field_end = $field_spec - 1;
}
if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesty (properly, 1-10)
if ($1 ne "") {
$field_begin = $1 - 1; # Change to zero-based indexing.
}
if ($2 ne "") {
$field_end = $2 - 1; # Change to zero-based indexing.
}
}
if (!defined $field_begin && !defined $field_end) {
die "Bad argument to -f option: $field_spec";
}
}
}
$symtab = shift @ARGV;
if (!defined $symtab) {
print STDERR "Usage: sym2int.pl [options] symtab [input transcriptions] > output transcriptions\n" .
"options: [--map-oov <oov-symbol> ] [-f <field-range> ]\n" .
"note: <field-range> can look like 4-5, or 4-, or 5-, or 1.\n";
}
open(F, "<$symtab") || die "Error opening symbol table file $symtab";
while(<F>) {
@A = split(" ", $_);
@A == 2 || die "bad line in symbol table file: $_";
$sym2int{$A[0]} = $A[1] + 0;
}
if (defined $map_oov && $map_oov !~ m/^\d+$/) { # not numeric-> look it up
if (!defined $sym2int{$map_oov}) { die "OOV symbol $map_oov not defined."; }
$map_oov = $sym2int{$map_oov};
}
$num_warning = 0;
$max_warning = 20;
while (<>) {
@A = split(" ", $_);
@B = ();
for ($n = 0; $n < @A; $n++) {
$a = $A[$n];
if ( (!defined $field_begin || $n >= $field_begin)
&& (!defined $field_end || $n <= $field_end)) {
$i = $sym2int{$a};
if (!defined ($i)) {
if (defined $map_oov) {
if ($num_warning++ < $max_warning) {
print STDERR "sym2int.pl: replacing $a with $map_oov\n";
if ($num_warning == $max_warning) {
print STDERR "sym2int.pl: not warning for OOVs any more times\n";
}
}
$i = $map_oov;
} else {
$pos = $n+1;
die "sym2int.pl: undefined symbol $a (in position $pos)\n";
}
}
$a = $i;
}
push @B, $a;
}
print join(" ", @B);
print "\n";
}
if ($num_warning > 0) {
print STDERR "** Replaced $num_warning instances of OOVs with $map_oov\n";
}
exit(0);
Loading…
Cancel
Save