You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/exps/deepspeech2/bin/test_hub.py

202 lines
6.7 KiB

3 years ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
import os
import sys
from pathlib import Path
import paddle
import soundfile
3 years ago
from paddlespeech.s2t.exps.deepspeech2.config import get_cfg_defaults
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.checkpoint import Checkpoint
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import print_arguments
from paddlespeech.s2t.utils.utility import UpdateConfig
3 years ago
logger = Log(__name__).getlog()
class DeepSpeech2Tester_hub():
def __init__(self, config, args):
self.args = args
self.config = config
self.audio_file = args.audio_file
self.collate_fn_test = SpeechCollator.from_config(config)
self._text_featurizer = TextFeaturizer(
unit_type=config.collator.unit_type, vocab_filepath=None)
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
result_transcripts = self.model.decode(
audio,
audio_len,
vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
beam_beta=cfg.beta,
beam_size=cfg.beam_size,
cutoff_prob=cfg.cutoff_prob,
cutoff_top_n=cfg.cutoff_top_n,
num_processes=cfg.num_proc_bsearch)
return result_transcripts
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
self.model.eval()
cfg = self.config
audio_file = self.audio_file
collate_fn_test = self.collate_fn_test
audio, _ = collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
audio = paddle.to_tensor(audio, dtype='float32')
audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
vocab_list = collate_fn_test.vocab_list
result_transcripts = self.compute_result_transcripts(
audio, audio_len, vocab_list, cfg.decoding)
logger.info("result_transcripts: " + result_transcripts[0])
def run_test(self):
self.resume()
try:
self.test()
except KeyboardInterrupt:
exit(-1)
def setup(self):
"""Setup the experiment.
"""
3 years ago
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
3 years ago
self.setup_output_dir()
self.setup_checkpointer()
self.setup_model()
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
if self.args.output:
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
else:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
def setup_model(self):
config = self.config.clone()
with UpdateConfig(config):
config.model.feat_size = self.collate_fn_test.feature_size
config.model.dict_size = self.collate_fn_test.vocab_size
if self.args.model_type == 'offline':
model = DeepSpeech2Model.from_config(config.model)
elif self.args.model_type == 'online':
model = DeepSpeech2ModelOnline.from_config(config.model)
else:
raise Exception("wrong model type")
self.model = model
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
def resume(self):
"""Resume from the checkpoint at checkpoints in the output
directory or load a specified checkpoint.
"""
params_path = self.args.checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
def check(audio_file):
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(audio_file)
except Exception as e:
logger.error(str(e))
logger.error(
"can not open the wav file, please check the audio file format")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
assert (sample_rate == 16000)
logger.info("The audio file format is right")
3 years ago
def main_sp(config, args):
exp = DeepSpeech2Tester_hub(config, args)
exp.setup()
exp.run_test()
3 years ago
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
"--model_type", type=str, default='offline', help='offline/online')
parser.add_argument("--audio_file", type=str, help='audio file path')
3 years ago
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
if not os.path.isfile(args.audio_file):
print("Please input the audio file path")
sys.exit(-1)
check(args.audio_file)
3 years ago
print("model_type:{}".format(args.model_type))
# https://yaml.org/type/float.html
config = get_cfg_defaults(args.model_type)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
main(config, args)