You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/paddlespeech/s2t/exps/u2/bin/test_wav.py

152 lines
5.2 KiB

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import os
import sys
from pathlib import Path
import distutils
import numpy as np
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
# Module-level logger, obtained from the project's Log wrapper and named after
# this module; used by U2Infer.run() and check() below.
logger = Log(__name__).getlog()
# TODO(hui zhang): dynamic load
class U2Infer():
    """Decode a single audio file with a U2 model.

    The constructor builds the feature transformation, the text featurizer
    and the model, then loads the checkpoint weights. ``run()`` reads the
    audio file, extracts features, decodes them and returns the best result.
    """

    def __init__(self, config, args):
        """Set up preprocessing, tokenizer and model, and load the weights.

        Args:
            config: configuration node. This class reads
                ``preprocess_config``, ``unit_type``, ``vocab_filepath``,
                ``spm_model_prefix``, ``feat_dim`` and (in ``run``)
                ``decode`` from it.
            args: parsed command-line arguments. This class reads
                ``audio_file``, ``ngpu``, ``checkpoint_path`` and ``debug``.
        """
        self.args = args
        self.config = config
        self.audio_file = args.audio_file

        # Feature pipeline, always applied with train=False.
        self.preprocess_conf = config.preprocess_config
        self.preprocess_args = {"train": False}
        self.preprocessing = Transformation(self.preprocess_conf)
        self.text_feature = TextFeaturizer(
            unit_type=config.unit_type,
            vocab=config.vocab_filepath,
            spm_model_prefix=config.spm_model_prefix)

        paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')

        # model
        # NOTE: model_conf is the same object as config (no copy is made), so
        # input_dim / output_dim are written onto the shared configuration.
        # UpdateConfig presumably makes the node temporarily writable -
        # TODO confirm against paddlespeech.s2t.utils.utility.
        model_conf = config
        with UpdateConfig(model_conf):
            model_conf.input_dim = config.feat_dim
            model_conf.output_dim = self.text_feature.vocab_size
        model = U2Model.from_config(model_conf)
        self.model = model
        self.model.eval()

        # load model
        # checkpoint_path is given without extension; ".pdparams" is appended.
        params_path = self.args.checkpoint_path + ".pdparams"
        model_dict = paddle.load(params_path)
        self.model.set_state_dict(model_dict)

    def run(self):
        """Decode ``self.audio_file`` and return the top transcript.

        Returns:
            ``result_transcripts[0][0]``, i.e. the first element of the first
            entry returned by ``self.model.decode``.
        """
        # Validate the file this instance was configured with. The previous
        # code read a module-global ``args``, which only exists when the file
        # is executed as a script and raised NameError otherwise.
        check(self.audio_file)

        with paddle.no_grad():
            # read
            # always_2d=True yields a (frames, channels) array; the sample
            # rate returned alongside it is not needed here.
            audio, _ = soundfile.read(
                self.audio_file, dtype="int16", always_2d=True)
            # Keep only the first channel.
            audio = audio[:, 0]
            logger.info(f"audio shape: {audio.shape}")

            # fbank
            feat = self.preprocessing(audio, **self.preprocess_args)
            logger.info(f"feat shape: {feat.shape}")
            if self.args.debug:
                # Dump the features to the current working directory.
                np.savetxt("feat.transform.txt", feat)

            # Length of the first feature axis, and the features with a
            # leading axis added (presumably the batch dimension).
            ilen = paddle.to_tensor(feat.shape[0])
            xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)

            decode_config = self.config.decode
            logger.info(f"decode cfg: {decode_config}")
            # reverse_weight is optional in the decode config; default 0.0.
            reverse_weight = getattr(decode_config, 'reverse_weight', 0.0)
            result_transcripts = self.model.decode(
                xs,
                ilen,
                text_feature=self.text_feature,
                decoding_method=decode_config.decoding_method,
                beam_size=decode_config.beam_size,
                ctc_weight=decode_config.ctc_weight,
                decoding_chunk_size=decode_config.decoding_chunk_size,
                num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                simulate_streaming=decode_config.simulate_streaming,
                reverse_weight=reverse_weight)
            rsl = result_transcripts[0][0]
            utt = Path(self.audio_file).name
            logger.info(f"hyp: {utt} {rsl}")
            return rsl
def check(audio_file):
    """Validate that ``audio_file`` exists, can be read and is 16 kHz audio.

    Every failure terminates the process via ``sys.exit(-1)``:
      * the path is missing/empty or does not point to a regular file,
      * ``soundfile.read`` cannot open the file,
      * the sample rate is not 16000.

    Args:
        audio_file: path of the audio file to validate (may be ``None`` when
            the command-line option was not supplied).
    """
    # Guard against None/"" first: os.path.isfile(None) raises TypeError
    # instead of returning False.
    if not audio_file or not os.path.isfile(audio_file):
        print("Please input the right audio file path")
        sys.exit(-1)

    logger.info("checking the audio file format......")
    try:
        # Only the sample rate is needed; the decoded samples are discarded.
        _, sample_rate = soundfile.read(audio_file)
    except Exception as e:
        logger.error(str(e))
        logger.error(
            "can not open the wav file, please check the audio file format")
        sys.exit(-1)

    logger.info("The sample rate is %d" % sample_rate)
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would silently skip this validation.
    if sample_rate != 16000:
        logger.error(
            "The sample rate must be 16000, but got %d" % sample_rate)
        sys.exit(-1)
    logger.info("The audio file format is right")
def main(config, args):
    """Build a U2Infer from ``config`` and ``args`` and run it once.

    The value returned by ``U2Infer.run()`` is not propagated; this function
    returns ``None``.
    """
    inferer = U2Infer(config, args)
    inferer.run()
def strtobool(value):
    """Convert a command-line truth string to ``bool``.

    Accepts the same spellings as the removed ``distutils.util.strtobool``
    (case-insensitive): y/yes/t/true/on/1 and n/no/f/false/off/0.

    Args:
        value: the raw option string.

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: if ``value`` is not one of the accepted spellings
            (argparse reports this as an invalid argument value).
    """
    normalized = str(value).strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


if __name__ == "__main__":
    parser = default_argument_parser()
    # save asr result to
    # NOTE(review): --result_file is parsed here but nothing in this file
    # reads it - confirm whether the transcript should be written there.
    parser.add_argument(
        "--result_file", type=str, help="path of save the asr result")
    parser.add_argument(
        "--audio_file", type=str, help="path of the input audio file")
    # Use the local strtobool: ``distutils.util`` is not guaranteed to be
    # loaded by a bare ``import distutils`` and distutils is gone in 3.12.
    parser.add_argument(
        "--debug",
        type=strtobool,
        default=False,
        help="for debug.")
    args = parser.parse_args()

    # Layer the configuration: main config file, then the decode config
    # (stored under ``config.decode``), then command-line overrides.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    if args.decode_cfg:
        decode_confs = CfgNode(new_allowed=True)
        decode_confs.merge_from_file(args.decode_cfg)
        config.decode = decode_confs
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    main(config, args)