revise the asr infer.py

pull/1048/head
huangyuxin 3 years ago
parent 1707244472
commit 3fadcde5e2

@@ -18,17 +18,17 @@ from typing import List
 from typing import Optional
 from typing import Union
 
+import librosa
 import paddle
 import soundfile
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import logger
 from ..utils import MODEL_HOME
-from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.io.collator import SpeechCollator
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
@@ -36,7 +36,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']
 
 pretrained_models = {
-    "wenetspeech_zh": {
+    "wenetspeech_zh_16k": {
         'url':
         'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
         'md5':
@@ -73,7 +73,15 @@ class ASRExecutor(BaseExecutor):
             default='wenetspeech',
             help='Choose model type of asr task.')
         self.parser.add_argument(
-            '--lang', type=str, default='zh', help='Choose model language.')
+            '--lang',
+            type=str,
+            default='zh',
+            help='Choose model language. zh or en')
+        self.parser.add_argument(
+            "--model_sample_rate",
+            type=int,
+            default=16000,
+            help='Choose the audio sample rate of the model. 8000 or 16000')
         self.parser.add_argument(
             '--config',
             type=str,
@@ -109,13 +117,15 @@ class ASRExecutor(BaseExecutor):
     def _init_from_path(self,
                         model_type: str='wenetspeech',
                         lang: str='zh',
+                        model_sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
                         ckpt_path: Optional[os.PathLike]=None):
         """
         Init model and other resources from a specific path.
         """
         if cfg_path is None or ckpt_path is None:
-            tag = model_type + '_' + lang
+            model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
+            tag = model_type + '_' + lang + '_' + model_sample_rate_str
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.cfg_path = os.path.join(res_path,
                                          pretrained_models[tag]['cfg_path'])
@@ -136,23 +146,24 @@ class ASRExecutor(BaseExecutor):
         #Init body.
         parser_args = self.parser_args
         paddle.set_device(parser_args.device)
-        self.config = get_cfg_defaults()
+        self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
         self.config.decoding.decoding_method = "attention_rescoring"
-        #self.config.freeze()
         model_conf = self.config.model
         logger.info(model_conf)
 
         with UpdateConfig(model_conf):
             if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+                from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.cmvn_path)
                 self.collate_fn_test = SpeechCollator.from_config(self.config)
-                model_conf.feat_size = self.collate_fn_test.feature_size
-                model_conf.dict_size = self.text_feature.vocab_size
+                model_conf.input_dim = self.collate_fn_test.feature_size
+                model_conf.output_dim = self.text_feature.vocab_size
             elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
                 self.text_feature = TextFeaturizer(
@@ -163,6 +174,7 @@ class ASRExecutor(BaseExecutor):
                 model_conf.output_dim = self.text_feature.vocab_size
             else:
                 raise Exception("wrong type")
+        self.config.freeze()
         model_class = dynamic_import(parser_args.model, model_alias)
         model = model_class.from_config(model_conf)
         self.model = model
@@ -182,13 +194,13 @@ class ASRExecutor(BaseExecutor):
         parser_args = self.parser_args
         config = self.config
         audio_file = input
-        logger.info("audio_file" + audio_file)
+        logger.info("Preprocess audio_file:" + audio_file)
 
         self.sr = config.collator.target_sample_rate
 
         # Get the object for feature extraction
         if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            audio, _ = collate_fn_test.process_utterance(
+            audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
             audio = paddle.to_tensor(audio, dtype='float32')
@@ -203,18 +215,30 @@ class ASRExecutor(BaseExecutor):
                 os.path.dirname(os.path.abspath(self.cfg_path)),
                 "preprocess.yaml")
-            cmvn_path: data / mean_std.json
             logger.info(preprocess_conf)
             preprocess_args = {"train": False}
             preprocessing = Transformation(preprocess_conf)
+            logger.info("read the audio file")
             audio, sample_rate = soundfile.read(
                 audio_file, dtype="int16", always_2d=True)
+
+            if self.change_format:
+                if audio.shape[1] >= 2:
+                    audio = audio.mean(axis=1)
+                else:
+                    audio = audio[:, 0]
+                audio = audio.astype("float32")
+                audio = librosa.resample(audio, sample_rate,
+                                         self.target_sample_rate)
+                sample_rate = self.target_sample_rate
+                audio = audio.astype("int16")
+            else:
+                audio = audio[:, 0]
 
             if sample_rate != self.sr:
                 logger.error(
                     f"sample rate error: {sample_rate}, need {self.sr} ")
                 sys.exit(-1)
-            audio = audio[:, 0]
             logger.info(f"audio shape: {audio.shape}")
 
             # fbank
             audio = preprocessing(audio, **preprocess_args)
@@ -282,6 +306,63 @@ class ASRExecutor(BaseExecutor):
         """
         return self.result_transcripts
 
+    def _check(self, audio_file: str, model_sample_rate: int):
+        self.target_sample_rate = model_sample_rate
+        if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
+            logger.error(
+                "please input --model_sample_rate 8000 or --model_sample_rate 16000")
+            raise Exception("invalid sample rate")
+            sys.exit(-1)
+
+        if not os.path.isfile(audio_file):
+            logger.error("Please input the right audio file path")
+            sys.exit(-1)
+
+        logger.info("checking the audio file format......")
+        try:
+            sig, sample_rate = soundfile.read(
+                audio_file, dtype="int16", always_2d=True)
+        except Exception as e:
+            logger.error(str(e))
+            logger.error(
+                "can not open the audio file, please check the audio file format is 'wav'. \n \
+                you can try to use sox to change the file format.\n \
+                For example: \n \
+                sample rate: 16k \n \
+                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+                sample rate: 8k \n \
+                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+                ")
+            sys.exit(-1)
+        logger.info("The sample rate is %d" % sample_rate)
+        if sample_rate != self.target_sample_rate:
+            logger.warning("The sample rate of the input file is not {}.\n \
+                            The program will resample the wav file to {}.\n \
+                            If the result does not meet your expectations\n \
+                            Please input the 16k 16bit 1 channel wav file. \
+                        ".format(self.target_sample_rate, self.target_sample_rate))
+            while (True):
+                logger.info(
+                    "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program."
+                )
+                content = input("Input(Y/N):")
+                if content.strip() == "Y" or content.strip(
+                ) == "y" or content.strip() == "yes" or content.strip() == "Yes":
+                    logger.info(
+                        "change the sample rate, channel to 16k and 1 channel")
+                    break
+                elif content.strip() == "N" or content.strip(
+                ) == "n" or content.strip() == "no" or content.strip() == "No":
+                    logger.info("Exit the program")
+                    exit(1)
+                else:
+                    logger.warning("Not regular input, please input again")
+
+            self.change_format = True
+        else:
+            logger.info("The audio file format is right")
+            self.change_format = False
+
     def execute(self, argv: List[str]) -> bool:
         """
         Command line entry.
@@ -290,24 +371,28 @@ class ASRExecutor(BaseExecutor):
         model = self.parser_args.model
         lang = self.parser_args.lang
+        model_sample_rate = self.parser_args.model_sample_rate
         config = self.parser_args.config
         ckpt_path = self.parser_args.ckpt_path
         audio_file = os.path.abspath(self.parser_args.input)
         device = self.parser_args.device
 
         try:
-            res = self(model, lang, config, ckpt_path, audio_file, device)
+            res = self(model, lang, model_sample_rate, config, ckpt_path, audio_file,
+                       device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
             print(e)
             return False
 
-    def __call__(self, model, lang, config, ckpt_path, audio_file, device):
+    def __call__(self, model, lang, model_sample_rate, config, ckpt_path, audio_file,
+                 device):
         """
         Python API to call an executor.
         """
-        self._init_from_path(model, lang, config, ckpt_path)
+        self._check(audio_file, model_sample_rate)
+        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path)
         self.preprocess(audio_file)
         self.infer()
         res = self.postprocess()  # Retrieve result of asr.

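For reference, a minimal, hypothetical usage sketch of the revised executor, driven through execute() so the parsed arguments are available before _init_from_path runs. The module path and the --input, --model, and --device flag names are assumptions inferred from the relative imports and the parser_args attributes visible in this diff, not something the commit itself documents; only --lang, --model_sample_rate, and --config appear in the hunks above.

    # Hypothetical driver script for the revised ASRExecutor.
    from paddlespeech.cli.asr.infer import ASRExecutor  # assumed module path

    asr_executor = ASRExecutor()
    ok = asr_executor.execute([
        '--input', 'zh_demo_16k.wav',    # wav to transcribe (assumed flag name)
        '--model', 'wenetspeech',        # model type, matching the parser default
        '--lang', 'zh',                  # zh or en
        '--model_sample_rate', '16000',  # 8000 or 16000; must match the wav,
                                         # otherwise _check() prompts to resample
        '--device', 'gpu:0',             # assumed flag name; passed to paddle.set_device
    ])
    print(ok)  # True on success; the transcript itself is written via logger.info

The sample rate flag has to agree with the pretrained model tag (e.g. wenetspeech_zh_16k), since _init_from_path now appends '16k' or '8k' to the tag when looking up pretrained_models.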