parent
fc8a7a152e
commit
d64f6e9ea5
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import configargparse
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
"""Get default arguments."""
|
||||||
|
parser = configargparse.ArgumentParser(
|
||||||
|
description="The parser for caculating the perplexity of transformer language model ",
|
||||||
|
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
||||||
|
formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnnlm", type=str, default=None, help="RNNLM model file to read")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--rnnlm-conf",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="RNNLM model config file to read")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocab_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="vocab path to for token2id")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpeprefix",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path of bpeprefix for loading")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--text_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The path of text file for testing ")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The number of gpu to use, 0 for using cpu instead")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
choices=("float16", "float32", "float64"),
|
||||||
|
default="float32",
|
||||||
|
help="Float precision (only available in --api v2)", )
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="The output directory to store the sentence PPL")
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args(args)
|
||||||
|
from deepspeech.exps.lm.transformer.lm_cacu_perplexity import run_get_perplexity
|
||||||
|
run_get_perplexity(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(sys.argv[1:])
|
@ -0,0 +1,132 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# Caculating the PPL of LM model
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from deepspeech.io.collator import TextCollatorSpm
|
||||||
|
from deepspeech.io.dataset import TextDataset
|
||||||
|
from deepspeech.models.lm_interface import dynamic_import_lm
|
||||||
|
from deepspeech.utils.log import Log
|
||||||
|
|
||||||
|
logger = Log(__name__).getlog()
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(config_path):
|
||||||
|
confs = CfgNode(new_allowed=True)
|
||||||
|
confs.merge_from_file(config_path)
|
||||||
|
return confs
|
||||||
|
|
||||||
|
|
||||||
|
def load_trained_lm(args):
|
||||||
|
lm_config = get_config(args.rnnlm_conf)
|
||||||
|
lm_model_module = lm_config.model_module
|
||||||
|
lm_class = dynamic_import_lm(lm_model_module)
|
||||||
|
lm = lm_class(**lm_config.model)
|
||||||
|
model_dict = paddle.load(args.rnnlm)
|
||||||
|
lm.set_state_dict(model_dict)
|
||||||
|
return lm, lm_config
|
||||||
|
|
||||||
|
|
||||||
|
def write_dict_into_file(ppl_dict, name):
|
||||||
|
with open(name, "w") as f:
|
||||||
|
for key in ppl_dict.keys():
|
||||||
|
f.write(key + " " + ppl_dict[key] + "\n")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def cacu_perplexity(
|
||||||
|
lm_model,
|
||||||
|
lm_config,
|
||||||
|
args,
|
||||||
|
log_base=None, ):
|
||||||
|
unit_type = lm_config.data.unit_type
|
||||||
|
batch_size = lm_config.decoding.batch_size
|
||||||
|
num_workers = lm_config.decoding.num_workers
|
||||||
|
text_file_path = args.text_path
|
||||||
|
|
||||||
|
total_nll = 0.0
|
||||||
|
total_ntokens = 0
|
||||||
|
ppl_dict = {}
|
||||||
|
len_dict = {}
|
||||||
|
text_dataset = TextDataset.from_file(text_file_path)
|
||||||
|
collate_fn_text = TextCollatorSpm(
|
||||||
|
unit_type=unit_type,
|
||||||
|
vocab_filepath=args.vocab_path,
|
||||||
|
spm_model_prefix=args.bpeprefix)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
text_dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
collate_fn=collate_fn_text,
|
||||||
|
num_workers=num_workers)
|
||||||
|
|
||||||
|
logger.info("start caculating PPL......")
|
||||||
|
for i, (keys, ys_input_pad, ys_output_pad,
|
||||||
|
y_lens) in enumerate(train_loader()):
|
||||||
|
|
||||||
|
ys_input_pad = paddle.to_tensor(ys_input_pad)
|
||||||
|
ys_output_pad = paddle.to_tensor(ys_output_pad)
|
||||||
|
_, unused_logp, unused_count, nll, nll_count = lm_model.forward(
|
||||||
|
ys_input_pad, ys_output_pad)
|
||||||
|
nll = nll.numpy()
|
||||||
|
nll_count = nll_count.numpy()
|
||||||
|
for key, _nll, ntoken in zip(keys, nll, nll_count):
|
||||||
|
if log_base is None:
|
||||||
|
utt_ppl = np.exp(_nll / ntoken)
|
||||||
|
else:
|
||||||
|
utt_ppl = log_base**(_nll / ntoken / np.log(log_base))
|
||||||
|
|
||||||
|
# Write PPL of each utts for debugging or analysis
|
||||||
|
ppl_dict[key] = str(utt_ppl)
|
||||||
|
len_dict[key] = str(ntoken)
|
||||||
|
|
||||||
|
total_nll += nll.sum()
|
||||||
|
total_ntokens += nll_count.sum()
|
||||||
|
logger.info("Current total nll: " + str(total_nll))
|
||||||
|
logger.info("Current total tokens: " + str(total_ntokens))
|
||||||
|
write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL"))
|
||||||
|
write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN"))
|
||||||
|
if log_base is None:
|
||||||
|
ppl = np.exp(total_nll / total_ntokens)
|
||||||
|
else:
|
||||||
|
ppl = log_base**(total_nll / total_ntokens / np.log(log_base))
|
||||||
|
|
||||||
|
if log_base is None:
|
||||||
|
log_base = np.e
|
||||||
|
else:
|
||||||
|
log_base = log_base
|
||||||
|
|
||||||
|
return ppl, log_base
|
||||||
|
|
||||||
|
|
||||||
|
def run_get_perplexity(args):
|
||||||
|
if args.ngpu > 1:
|
||||||
|
raise NotImplementedError("only single GPU decoding is supported")
|
||||||
|
if args.ngpu == 1:
|
||||||
|
device = "gpu:0"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
paddle.set_device(device)
|
||||||
|
dtype = getattr(paddle, args.dtype)
|
||||||
|
logger.info(f"Decoding device={device}, dtype={dtype}")
|
||||||
|
lm_model, lm_config = load_trained_lm(args)
|
||||||
|
lm_model.to(device=device, dtype=dtype)
|
||||||
|
lm_model.eval()
|
||||||
|
PPL, log_base = cacu_perplexity(lm_model, lm_config, args, None)
|
||||||
|
logger.info("Final PPL: " + str(PPL))
|
||||||
|
logger.info("The log base is:" + str("%.2f" % log_base))
|
@ -0,0 +1,53 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
stage=-1
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
expdir=exp
|
||||||
|
datadir=data
|
||||||
|
|
||||||
|
ngpu=0
|
||||||
|
|
||||||
|
# lm params
|
||||||
|
rnnlm_config_path=conf/lm/transformer.yaml
|
||||||
|
lmexpdir=exp/lm/transformer
|
||||||
|
lang_model=transformerLM.pdparams
|
||||||
|
|
||||||
|
#data path
|
||||||
|
test_set=${datadir}/test_clean/text
|
||||||
|
test_set_lower=${datadir}/test_clean/text_lower
|
||||||
|
train_set=train_960
|
||||||
|
|
||||||
|
# bpemode (unigram or bpe)
|
||||||
|
nbpe=5000
|
||||||
|
bpemode=unigram
|
||||||
|
bpeprefix=${datadir}/lang_char/${train_set}_${bpemode}${nbpe}
|
||||||
|
bpemodel=${bpeprefix}.model
|
||||||
|
|
||||||
|
vocabfile=${bpeprefix}_units.txt
|
||||||
|
vocabfile_lower=${bpeprefix}_units_lower.txt
|
||||||
|
|
||||||
|
output_dir=${expdir}/lm/transformer/perplexity
|
||||||
|
|
||||||
|
mkdir -p ${output_dir}
|
||||||
|
|
||||||
|
# Transform the data upper case to lower
|
||||||
|
if [ -f ${vocabfile} ]; then
|
||||||
|
tr A-Z a-z < ${vocabfile} > ${vocabfile_lower}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -f ${test_set} ]; then
|
||||||
|
tr A-Z a-z < ${test_set} > ${test_set_lower}
|
||||||
|
fi
|
||||||
|
|
||||||
|
python ${LM_BIN_DIR}/cacu_perplexity.py \
|
||||||
|
--rnnlm ${lmexpdir}/${lang_model} \
|
||||||
|
--rnnlm-conf ${rnnlm_config_path} \
|
||||||
|
--vocab_path ${vocabfile_lower} \
|
||||||
|
--bpeprefix ${bpeprefix} \
|
||||||
|
--text_path ${test_set_lower} \
|
||||||
|
--output_dir ${output_dir} \
|
||||||
|
--ngpu ${ngpu}
|
||||||
|
|
Loading…
Reference in new issue