You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
133 lines
4.3 KiB
133 lines
4.3 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# Calculating the perplexity (PPL) of an LM model
|
|
import os
|
|
|
|
import numpy as np
|
|
import paddle
|
|
from paddle.io import DataLoader
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.s2t.models.lm.dataset import TextCollatorSpm
|
|
from paddlespeech.s2t.models.lm.dataset import TextDataset
|
|
from paddlespeech.s2t.models.lm_interface import dynamic_import_lm
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
# Module-level logger, named after this module, shared by every function below.
logger = Log(__name__).getlog()
|
|
|
|
|
|
def get_config(config_path):
    """Load a configuration file into a fresh ``CfgNode``.

    Args:
        config_path: Path of the configuration file to merge.

    Returns:
        A ``CfgNode`` created with ``new_allowed=True`` and populated from
        ``config_path``.
    """
    # new_allowed=True so that keys from the file can be merged into a node
    # that starts out empty.
    cfg = CfgNode(new_allowed=True)
    cfg.merge_from_file(config_path)
    return cfg
|
|
|
|
|
|
def load_trained_lm(args):
    """Instantiate the language model described by a config and load its weights.

    Args:
        args: Namespace-like object providing ``rnnlm_conf`` (path of the LM
            configuration file) and ``rnnlm`` (path of the saved parameters).

    Returns:
        A ``(lm, lm_config)`` tuple: the model with its state dict loaded, and
        the configuration it was built from.
    """
    lm_config = get_config(args.rnnlm_conf)

    # The config names the module that holds the model class; the ``model``
    # section supplies the constructor keyword arguments.
    model_cls = dynamic_import_lm(lm_config.model_module)
    lm = model_cls(**lm_config.model)

    lm.set_state_dict(paddle.load(args.rnnlm))
    return lm, lm_config
|
|
|
|
|
|
def write_dict_into_file(ppl_dict, name):
    """Write a mapping to a text file, one ``"<key> <value>"`` pair per line.

    Args:
        ppl_dict: Mapping to dump. Keys and values are formatted with
            ``str()``, so non-string values (e.g. numbers) are accepted.
        name: Path of the output file; an existing file is overwritten.

    Returns:
        None.
    """
    # Explicit UTF-8 so that non-ASCII keys are written identically on every
    # platform instead of depending on the locale's default encoding.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{key} {value}\n" for key, value in ppl_dict.items())
|
|
|
|
|
|
def cacu_perplexity(
        lm_model,
        lm_config,
        args,
        log_base=None, ):
    """Compute the perplexity of a language model over a text file.

    Per-entry perplexities and token counts are written to the files
    ``uttPPL`` and ``uttLEN`` inside ``args.output_dir``.

    Args:
        lm_model: Model whose ``forward(ys_input_pad, ys_output_pad)`` returns
            a 5-tuple; the last two items are used as the negative
            log-likelihoods and token counts (presumably one entry per
            utterance -- confirm against the model implementation).
        lm_config: Configuration providing ``data.unit_type``,
            ``decoding.batch_size`` and ``decoding.num_workers``.
        args: Namespace-like object providing ``text_path``, ``vocab_path``,
            ``bpeprefix`` and ``output_dir``.
        log_base: Base used in the perplexity formula; ``None`` selects the
            natural-log form ``exp(nll / ntokens)``.

    Returns:
        A ``(ppl, log_base)`` tuple: the perplexity over all tokens, and the
        base that was used (``np.e`` when ``log_base`` was ``None``).
    """

    def _perplexity(nll_value, token_count):
        # Single definition of the formula, shared by the per-entry and the
        # overall computation.
        if log_base is None:
            return np.exp(nll_value / token_count)
        return log_base**(nll_value / token_count / np.log(log_base))

    unit_type = lm_config.data.unit_type
    batch_size = lm_config.decoding.batch_size
    num_workers = lm_config.decoding.num_workers
    text_file_path = args.text_path

    # Create the output directory before the (potentially long) evaluation so
    # that a missing or unwritable directory fails immediately rather than
    # after every batch has been processed. An empty output_dir means the
    # current directory, which needs no creation.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    total_nll = 0.0
    total_ntokens = 0
    ppl_dict = {}
    len_dict = {}

    text_dataset = TextDataset.from_file(text_file_path)
    collate_fn_text = TextCollatorSpm(
        unit_type=unit_type,
        vocab_filepath=args.vocab_path,
        spm_model_prefix=args.bpeprefix)
    data_loader = DataLoader(
        text_dataset,
        batch_size=batch_size,
        collate_fn=collate_fn_text,
        num_workers=num_workers)

    logger.info("start caculating PPL......")
    # Evaluation only: no gradients are needed, so avoid tracking them.
    with paddle.no_grad():
        for keys, ys_input_pad, ys_output_pad, _y_lens in data_loader():
            ys_input_pad = paddle.to_tensor(ys_input_pad)
            ys_output_pad = paddle.to_tensor(ys_output_pad)
            _, _unused_logp, _unused_count, nll, nll_count = lm_model.forward(
                ys_input_pad, ys_output_pad)
            nll = nll.numpy()
            nll_count = nll_count.numpy()

            # Keep the PPL and length of each entry for debugging or analysis.
            for key, _nll, ntoken in zip(keys, nll, nll_count):
                ppl_dict[key] = str(_perplexity(_nll, ntoken))
                len_dict[key] = str(ntoken)

            total_nll += nll.sum()
            total_ntokens += nll_count.sum()
            logger.info("Current total nll: " + str(total_nll))
            logger.info("Current total tokens: " + str(total_ntokens))

    write_dict_into_file(ppl_dict, os.path.join(args.output_dir, "uttPPL"))
    write_dict_into_file(len_dict, os.path.join(args.output_dir, "uttLEN"))

    ppl = _perplexity(total_nll, total_ntokens)

    # Report the base that was effectively used.
    if log_base is None:
        log_base = np.e

    return ppl, log_base
|
|
|
|
|
|
def run_get_perplexity(args):
    """Load the trained language model and log its perplexity.

    Args:
        args: Namespace-like object providing ``ngpu`` and ``dtype`` (name of
            a ``paddle`` attribute), plus the fields read by
            ``load_trained_lm`` and ``cacu_perplexity``.

    Raises:
        NotImplementedError: If ``args.ngpu`` is greater than 1.
    """
    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    # Exactly one GPU selects the first device; any other count runs on CPU.
    device = "gpu:0" if args.ngpu == 1 else "cpu"
    paddle.set_device(device)

    dtype = getattr(paddle, args.dtype)
    logger.info(f"Decoding device={device}, dtype={dtype}")

    lm_model, lm_config = load_trained_lm(args)
    lm_model.to(device=device, dtype=dtype)
    lm_model.eval()

    corpus_ppl, used_base = cacu_perplexity(lm_model, lm_config, args, None)
    logger.info(f"Final PPL: {corpus_ppl!s}")
    logger.info(f"The log base is:{used_base:.2f}")
|