Merge pull request #952 from Jackwaterveg/dev_transformerLM
Add the feature: caculating the perplexity of transformerLMpull/957/head
commit
980944dab1
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,82 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
|
||||
import configargparse
|
||||
|
||||
|
||||
def get_parser():
|
||||
"""Get default arguments."""
|
||||
parser = configargparse.ArgumentParser(
|
||||
description="The parser for caculating the perplexity of transformer language model ",
|
||||
config_file_parser_class=configargparse.YAMLConfigFileParser,
|
||||
formatter_class=configargparse.ArgumentDefaultsHelpFormatter, )
|
||||
|
||||
parser.add_argument(
|
||||
"--rnnlm", type=str, default=None, help="RNNLM model file to read")
|
||||
|
||||
parser.add_argument(
|
||||
"--rnnlm-conf",
|
||||
type=str,
|
||||
default=None,
|
||||
help="RNNLM model config file to read")
|
||||
|
||||
parser.add_argument(
|
||||
"--vocab_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="vocab path to for token2id")
|
||||
|
||||
parser.add_argument(
|
||||
"--bpeprefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path of bpeprefix for loading")
|
||||
|
||||
parser.add_argument(
|
||||
"--text_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path of text file for testing ")
|
||||
|
||||
parser.add_argument(
|
||||
"--ngpu",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of gpu to use, 0 for using cpu instead")
|
||||
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=("float16", "float32", "float64"),
|
||||
default="float32",
|
||||
help="Float precision (only available in --api v2)", )
|
||||
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="The output directory to store the sentence PPL")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
parser = get_parser()
|
||||
args = parser.parse_args(args)
|
||||
from deepspeech.exps.lm.transformer.lm_cacu_perplexity import run_get_perplexity
|
||||
run_get_perplexity(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:])
|
@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Caculating the PPL of LM model
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.io import DataLoader
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from deepspeech.io.collator import TextCollatorSpm
|
||||
from deepspeech.io.dataset import TextDataset
|
||||
from deepspeech.models.lm_interface import dynamic_import_lm
|
||||
from deepspeech.utils.log import Log
|
||||
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
|
||||
def get_config(config_path):
|
||||
confs = CfgNode(new_allowed=True)
|
||||
confs.merge_from_file(config_path)
|
||||
return confs
|
||||
|
||||
|
||||
def load_trained_lm(args):
|
||||
lm_config = get_config(args.rnnlm_conf)
|
||||
lm_model_module = lm_config.model_module
|
||||
lm_class = dynamic_import_lm(lm_model_module)
|
||||
lm = lm_class(**lm_config.model)
|
||||
model_dict = paddle.load(args.rnnlm)
|
||||
lm.set_state_dict(model_dict)
|
||||
return lm, lm_config
|
||||
|
||||
|
||||
def write_dict_into_file(ppl_dict, name):
|
||||
with open(name, "w") as f:
|
||||
for key in ppl_dict.keys():
|
||||
f.write(key + " " + ppl_dict[key] + "\n")
|
||||
return
|
||||
|
||||
|
||||
def cacu_perplexity(
|
||||
lm_model,
|
||||
lm_config,
|
||||
args,
|
||||
log_base=None, ):
|
||||
unit_type = lm_config.data.unit_type
|
||||
batch_size = lm_config.decoding.batch_size
|
||||
num_workers = lm_config.decoding.num_workers
|
||||
text_file_path = args.text_path
|
||||
|
||||
total_nll = 0.0
|
||||
total_ntokens = 0
|
||||
ppl_dict = {}
|
||||
len_dict = {}
|
||||
text_dataset = TextDataset.from_file(text_file_path)
|
||||
collate_fn_text = TextCollatorSpm(
|
||||
unit_type=unit_type,
|
||||
vocab_filepath=args.vocab_path,
|
||||
spm_model_prefix=args.bpeprefix)
|
||||
train_loader = DataLoader(
|
||||
text_dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=collate_fn_text,
|
||||
num_workers=num_workers)
|
||||
|
||||
logger.info("start caculating PPL......")
|
||||
for i, (keys, ys_input_pad, ys_output_pad,
|
||||
y_lens) in enumerate(train_loader()):
|
||||
|
||||
ys_input_pad = paddle.to_tensor(ys_input_pad)
|
||||
ys_output_pad = paddle.to_tensor(ys_output_pad)
|
||||
_, unused_logp, unused_count, nll, nll_count = lm_model.forward(
|
||||
ys_input_pad, ys_output_pad)
|
||||
nll = nll.numpy()
|
||||
nll_count = nll_count.numpy()
|
||||
for key, _nll, ntoken in zip(keys, nll, nll_count):
|
||||
if log_base is None:
|
||||
utt_ppl = np.exp(_nll / ntoken)
|
||||
else:
|
||||
utt_ppl = log_base**(_nll / ntoken / np.log(log_base))
|
||||
|
||||
# Write PPL of each utts for debugging or analysis
|
||||
ppl_dict[key] = str(utt_ppl)
|
||||
len_dict[key] = str(ntoken)
|
||||
|
||||
total_nll += nll.sum()
|
||||
total_ntokens += nll_count.sum()
|
||||
logger.info("Current total nll: " + str(total_nll))
|
||||
logger.info("Current total tokens: " + str(total_ntokens))
|
||||
write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL"))
|
||||
write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN"))
|
||||
if log_base is None:
|
||||
ppl = np.exp(total_nll / total_ntokens)
|
||||
else:
|
||||
ppl = log_base**(total_nll / total_ntokens / np.log(log_base))
|
||||
|
||||
if log_base is None:
|
||||
log_base = np.e
|
||||
else:
|
||||
log_base = log_base
|
||||
|
||||
return ppl, log_base
|
||||
|
||||
|
||||
def run_get_perplexity(args):
|
||||
if args.ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
if args.ngpu == 1:
|
||||
device = "gpu:0"
|
||||
else:
|
||||
device = "cpu"
|
||||
paddle.set_device(device)
|
||||
dtype = getattr(paddle, args.dtype)
|
||||
logger.info(f"Decoding device={device}, dtype={dtype}")
|
||||
lm_model, lm_config = load_trained_lm(args)
|
||||
lm_model.to(device=device, dtype=dtype)
|
||||
lm_model.eval()
|
||||
PPL, log_base = cacu_perplexity(lm_model, lm_config, args, None)
|
||||
logger.info("Final PPL: " + str(PPL))
|
||||
logger.info("The log base is:" + str("%.2f" % log_base))
|
@ -0,0 +1,53 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
|
||||
expdir=exp
|
||||
datadir=data
|
||||
|
||||
ngpu=0
|
||||
|
||||
# lm params
|
||||
rnnlm_config_path=conf/lm/transformer.yaml
|
||||
lmexpdir=exp/lm/transformer
|
||||
lang_model=transformerLM.pdparams
|
||||
|
||||
#data path
|
||||
test_set=${datadir}/test_clean/text
|
||||
test_set_lower=${datadir}/test_clean/text_lower
|
||||
train_set=train_960
|
||||
|
||||
# bpemode (unigram or bpe)
|
||||
nbpe=5000
|
||||
bpemode=unigram
|
||||
bpeprefix=${datadir}/lang_char/${train_set}_${bpemode}${nbpe}
|
||||
bpemodel=${bpeprefix}.model
|
||||
|
||||
vocabfile=${bpeprefix}_units.txt
|
||||
vocabfile_lower=${bpeprefix}_units_lower.txt
|
||||
|
||||
output_dir=${expdir}/lm/transformer/perplexity
|
||||
|
||||
mkdir -p ${output_dir}
|
||||
|
||||
# Transform the data upper case to lower
|
||||
if [ -f ${vocabfile} ]; then
|
||||
tr A-Z a-z < ${vocabfile} > ${vocabfile_lower}
|
||||
fi
|
||||
|
||||
if [ -f ${test_set} ]; then
|
||||
tr A-Z a-z < ${test_set} > ${test_set_lower}
|
||||
fi
|
||||
|
||||
python ${LM_BIN_DIR}/cacu_perplexity.py \
|
||||
--rnnlm ${lmexpdir}/${lang_model} \
|
||||
--rnnlm-conf ${rnnlm_config_path} \
|
||||
--vocab_path ${vocabfile_lower} \
|
||||
--bpeprefix ${bpeprefix} \
|
||||
--text_path ${test_set_lower} \
|
||||
--output_dir ${output_dir} \
|
||||
--ngpu ${ngpu}
|
||||
|
Loading…
Reference in new issue