Merge pull request #2 from Jackwaterveg/cli_infer

LGTM
commit ba0dc3c1c6 by KP, 3 years ago (committed via GitHub)

@@ -18,17 +18,17 @@ from typing import List
 from typing import Optional
 from typing import Union
 
+import librosa
 import paddle
 import soundfile
+from yacs.config import CfgNode
 
 from ..executor import BaseExecutor
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import logger
 from ..utils import MODEL_HOME
-from paddlespeech.s2t.exps.u2.config import get_cfg_defaults
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
-from paddlespeech.s2t.io.collator import SpeechCollator
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
@@ -36,7 +36,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
 __all__ = ['ASRExecutor']
 
 pretrained_models = {
-    "wenetspeech_zh": {
+    "wenetspeech_zh_16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
        'md5':
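Note: with this rename, each pretrained resource is keyed by model, language, and sample rate. A minimal lookup sketch; only the 'url', 'md5', and 'cfg_path' fields are visible in this diff, any other field name would be an assumption:

    # Illustration only: resolving a resource entry by its new tag.
    tag = 'wenetspeech_zh_16k'
    res = pretrained_models[tag]
    print(res['url'])       # archive to download
    print(res['md5'])       # checksum for the downloaded archive
    print(res['cfg_path'])  # config file inside the extracted archive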
@@ -73,7 +73,15 @@ class ASRExecutor(BaseExecutor):
             default='wenetspeech',
             help='Choose model type of asr task.')
         self.parser.add_argument(
-            '--lang', type=str, default='zh', help='Choose model language.')
+            '--lang',
+            type=str,
+            default='zh',
+            help='Choose model language. zh or en')
+        self.parser.add_argument(
+            "--model_sample_rate",
+            type=int,
+            default=16000,
+            help='Choose the audio sample rate of the model. 8000 or 16000')
         self.parser.add_argument(
             '--config',
             type=str,
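For reference, a sketch of driving the new flag through execute(); the import path is an assumption, and the defaults cover --model and --config:

    # Sketch: invoking the CLI entry point from Python (import path assumed).
    from paddlespeech.cli.asr.infer import ASRExecutor

    exe = ASRExecutor()
    ok = exe.execute([
        '--input', './zh.wav',           # audio file to transcribe
        '--lang', 'zh',                  # zh or en
        '--model_sample_rate', '16000',  # 8000 or 16000
        '--device', 'cpu',
    ])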
@@ -109,13 +117,16 @@ class ASRExecutor(BaseExecutor):
     def _init_from_path(self,
                         model_type: str='wenetspeech',
                         lang: str='zh',
+                        model_sample_rate: int=16000,
                         cfg_path: Optional[os.PathLike]=None,
-                        ckpt_path: Optional[os.PathLike]=None):
+                        ckpt_path: Optional[os.PathLike]=None,
+                        device: str='cpu'):
         """
         Init model and other resources from a specific path.
         """
         if cfg_path is None or ckpt_path is None:
-            tag = model_type + '_' + lang
+            model_sample_rate_str = '16k' if model_sample_rate == 16000 else '8k'
+            tag = model_type + '_' + lang + '_' + model_sample_rate_str
             res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
             self.cfg_path = os.path.join(res_path,
                                          pretrained_models[tag]['cfg_path'])
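The suffix is derived with a one-liner, so any value other than 16000 falls back to '8k'; the _check() method added later in this diff rejects rates other than 8000/16000 before this code runs. Illustration only:

    # Tag construction as in _init_from_path above.
    for sr in (16000, 8000):
        sr_str = '16k' if sr == 16000 else '8k'
        print('wenetspeech' + '_' + 'zh' + '_' + sr_str)
    # wenetspeech_zh_16k
    # wenetspeech_zh_8k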
@@ -130,40 +141,44 @@ class ASRExecutor(BaseExecutor):
             res_path = os.path.dirname(
                 os.path.dirname(os.path.abspath(self.cfg_path)))
-        # Enter the path of model root
-        os.chdir(res_path)
 
         #Init body.
-        parser_args = self.parser_args
-        paddle.set_device(parser_args.device)
-        self.config = get_cfg_defaults()
+        paddle.set_device(device)
+        self.config = CfgNode(new_allowed=True)
         self.config.merge_from_file(self.cfg_path)
         self.config.decoding.decoding_method = "attention_rescoring"
-        #self.config.freeze()
         model_conf = self.config.model
         logger.info(model_conf)
         with UpdateConfig(model_conf):
-            if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
+            if model_type == "ds2_online" or model_type == "ds2_offline":
+                from paddlespeech.s2t.io.collator import SpeechCollator
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.config.collator.vocab_filepath = os.path.join(
+                self.config.collator.mean_std_filepath = os.path.join(
                     res_path, self.config.collator.cmvn_path)
                 self.collate_fn_test = SpeechCollator.from_config(self.config)
-                model_conf.feat_size = self.collate_fn_test.feature_size
-                model_conf.dict_size = self.text_feature.vocab_size
-            elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+                text_feature = TextFeaturizer(
+                    unit_type=self.config.collator.unit_type,
+                    vocab_filepath=self.config.collator.vocab_filepath,
+                    spm_model_prefix=self.config.collator.spm_model_prefix)
+                model_conf.input_dim = self.collate_fn_test.feature_size
+                model_conf.output_dim = text_feature.vocab_size
+            elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
                 self.config.collator.vocab_filepath = os.path.join(
                     res_path, self.config.collator.vocab_filepath)
-                self.text_feature = TextFeaturizer(
+                text_feature = TextFeaturizer(
                     unit_type=self.config.collator.unit_type,
                     vocab_filepath=self.config.collator.vocab_filepath,
                     spm_model_prefix=self.config.collator.spm_model_prefix)
                 model_conf.input_dim = self.config.collator.feat_dim
-                model_conf.output_dim = self.text_feature.vocab_size
+                model_conf.output_dim = text_feature.vocab_size
             else:
                 raise Exception("wrong type")
-        model_class = dynamic_import(parser_args.model, model_alias)
+        self.config.freeze()
+        # Enter the path of model root
+        os.chdir(res_path)
+        model_class = dynamic_import(model_type, model_alias)
         model = model_class.from_config(model_conf)
         self.model = model
         self.model.eval()
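The config bootstrap now builds a plain yacs CfgNode instead of importing get_cfg_defaults, so the YAML file fully defines the schema. A self-contained sketch of the same pattern; the YAML path is a placeholder:

    # Sketch of the new config flow with yacs.
    from yacs.config import CfgNode

    config = CfgNode(new_allowed=True)             # accept keys not predeclared in code
    config.merge_from_file('conf/conformer.yaml')  # hypothetical path
    config.decoding.decoding_method = "attention_rescoring"
    config.freeze()                                # read-only from here on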
@@ -173,75 +188,94 @@ class ASRExecutor(BaseExecutor):
         model_dict = paddle.load(params_path)
         self.model.set_state_dict(model_dict)
 
-    def preprocess(self, input: Union[str, os.PathLike]):
+    def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
         """
         Input preprocess and return paddle.Tensor stored in self.input.
         Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
         """
-        parser_args = self.parser_args
-        config = self.config
         audio_file = input
-        logger.info("audio_file" + audio_file)
-        self.sr = config.collator.target_sample_rate
+        logger.info("Preprocess audio_file:" + audio_file)
+        config_target_sample_rate = self.config.collator.target_sample_rate
 
         # Get the object for feature extraction
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            audio, _ = collate_fn_test.process_utterance(
+        if model_type == "ds2_online" or model_type == "ds2_offline":
+            audio, _ = self.collate_fn_test.process_utterance(
                 audio_file=audio_file, transcript=" ")
             audio_len = audio.shape[0]
             audio = paddle.to_tensor(audio, dtype='float32')
-            self.audio_len = paddle.to_tensor(audio_len)
-            self.audio = paddle.unsqueeze(audio, axis=0)
-            self.vocab_list = collate_fn_test.vocab_list
-            logger.info(f"audio feat shape: {self.audio.shape}")
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
+            audio_len = paddle.to_tensor(audio_len)
+            audio = paddle.unsqueeze(audio, axis=0)
+            vocab_list = self.collate_fn_test.vocab_list
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             logger.info("get the preprocess conf")
             preprocess_conf = os.path.join(
                 os.path.dirname(os.path.abspath(self.cfg_path)),
                 "preprocess.yaml")
-            cmvn_path: data / mean_std.json
             logger.info(preprocess_conf)
             preprocess_args = {"train": False}
             preprocessing = Transformation(preprocess_conf)
+            logger.info("read the audio file")
             audio, sample_rate = soundfile.read(
                 audio_file, dtype="int16", always_2d=True)
-            if sample_rate != self.sr:
+
+            if self.change_format:
+                if audio.shape[1] >= 2:
+                    audio = audio.mean(axis=1)
+                else:
+                    audio = audio[:, 0]
+                audio = audio.astype("float32")
+                audio = librosa.resample(audio, sample_rate,
+                                         self.target_sample_rate)
+                sample_rate = self.target_sample_rate
+                audio = audio.astype("int16")
+            else:
+                audio = audio[:, 0]
+
+            if sample_rate != config_target_sample_rate:
                 logger.error(
-                    f"sample rate error: {sample_rate}, need {self.sr} ")
+                    f"sample rate error: {sample_rate}, need {config_target_sample_rate} ")
                 sys.exit(-1)
-            audio = audio[:, 0]
+
             logger.info(f"audio shape: {audio.shape}")
             # fbank
             audio = preprocessing(audio, **preprocess_args)
 
-            self.audio_len = paddle.to_tensor(audio.shape[0])
-            self.audio = paddle.to_tensor(
-                audio, dtype='float32').unsqueeze(axis=0)
-            logger.info(f"audio feat shape: {self.audio.shape}")
+            audio_len = paddle.to_tensor(audio.shape[0])
+            audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
+            text_feature = TextFeaturizer(
+                unit_type=self.config.collator.unit_type,
+                vocab_filepath=self.config.collator.vocab_filepath,
+                spm_model_prefix=self.config.collator.spm_model_prefix)
+            self._inputs["audio"] = audio
+            self._inputs["audio_len"] = audio_len
+            logger.info(f"audio feat shape: {audio.shape}")
         else:
             raise Exception("wrong type")
 
     @paddle.no_grad()
-    def infer(self):
+    def infer(self, model_type: str):
         """
         Model inference and result stored in self.output.
         """
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
         cfg = self.config.decoding
-        parser_args = self.parser_args
-        audio = self.audio
-        audio_len = self.audio_len
-        if parser_args.model == "ds2_online" or parser_args.model == "ds2_offline":
-            vocab_list = self.vocab_list
+        audio = self._inputs["audio"]
+        audio_len = self._inputs["audio_len"]
+        if model_type == "ds2_online" or model_type == "ds2_offline":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                vocab_list,
+                text_feature.vocab_list,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
                 beam_alpha=cfg.alpha,
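The new change_format branch downmixes to mono and resamples before feature extraction. The same logic in isolation, assuming librosa < 0.10, where resample still accepts the sample rates positionally:

    # Standalone sketch of the downmix + resample step (assumptions noted above).
    import librosa
    import numpy as np

    def to_mono_and_resample(audio: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
        # audio: int16 samples shaped (frames, channels), as returned by
        # soundfile.read(..., dtype="int16", always_2d=True)
        mono = audio.mean(axis=1) if audio.shape[1] >= 2 else audio[:, 0]
        resampled = librosa.resample(mono.astype("float32"), sr, target_sr)
        return resampled.astype("int16")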
@@ -250,14 +284,13 @@ class ASRExecutor(BaseExecutor):
                 cutoff_prob=cfg.cutoff_prob,
                 cutoff_top_n=cfg.cutoff_top_n,
                 num_processes=cfg.num_proc_bsearch)
-            self.result_transcripts = result_transcripts[0]
-        elif parser_args.model == "conformer" or parser_args.model == "transformer" or parser_args.model == "wenetspeech":
-            text_feature = self.text_feature
+            self._outputs["result"] = result_transcripts[0]
+        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
             result_transcripts = self.model.decode(
                 audio,
                 audio_len,
-                text_feature=self.text_feature,
+                text_feature=text_feature,
                 decoding_method=cfg.decoding_method,
                 lang_model_path=cfg.lang_model_path,
                 beam_alpha=cfg.alpha,
@@ -270,46 +303,110 @@ class ASRExecutor(BaseExecutor):
                 decoding_chunk_size=cfg.decoding_chunk_size,
                 num_decoding_left_chunks=cfg.num_decoding_left_chunks,
                 simulate_streaming=cfg.simulate_streaming)
-            self.result_transcripts = result_transcripts[0][0]
+            self._outputs["result"] = result_transcripts[0][0]
         else:
             raise Exception("invalid model name")
-        pass
 
     def postprocess(self) -> Union[str, os.PathLike]:
         """
         Output postprocess and return human-readable results such as texts and audio files.
         """
-        return self.result_transcripts
+        return self._outputs["result"]
+    def _check(self, audio_file: str, model_sample_rate: int):
+        self.target_sample_rate = model_sample_rate
+        if self.target_sample_rate != 16000 and self.target_sample_rate != 8000:
+            logger.error(
+                "please input --model_sample_rate 8000 or --model_sample_rate 16000")
+            raise Exception("invalid sample rate")
+
+        if not os.path.isfile(audio_file):
+            logger.error("Please input the right audio file path")
+            sys.exit(-1)
+
+        logger.info("checking the audio file format......")
+        try:
+            sig, sample_rate = soundfile.read(
+                audio_file, dtype="int16", always_2d=True)
+        except Exception as e:
+            logger.error(str(e))
+            logger.error(
+                "can not open the audio file, please check that the audio file format is 'wav'. \n \
+                You can try to use sox to change the file format.\n \
+                For example: \n \
+                sample rate: 16k \n \
+                sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
+                sample rate: 8k \n \
+                sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
+                ")
+            sys.exit(-1)
+        logger.info("The sample rate is %d" % sample_rate)
+        if sample_rate != self.target_sample_rate:
+            logger.warning(
+                "The sample rate of the input file is not {}.\n \
+                The program will resample the wav file to {}.\n \
+                If the result does not meet your expectations,\n \
+                please input a 16k 16bit 1 channel wav file.\n \
+                ".format(self.target_sample_rate, self.target_sample_rate))
+            while True:
+                logger.info(
+                    "Whether to change the sample rate and the channel. Y: change the sample rate. N: exit the program."
+                )
+                content = input("Input(Y/N):")
+                if content.strip() in ("Y", "y", "yes", "Yes"):
+                    logger.info(
+                        "change the sample rate and channel to 16k and 1 channel")
+                    break
+                elif content.strip() in ("N", "n", "no", "No"):
+                    logger.info("Exit the program")
+                    exit(1)
+                else:
+                    logger.warning("Not regular input, please input again")
+            self.change_format = True
+        else:
+            logger.info("The audio file format is right")
+            self.change_format = False
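_check() only records the decision in self.change_format; the actual conversion happens later in preprocess(). An equivalent, more compact form of the confirmation loop, as a behavioral sketch rather than the committed code:

    # Compact form of the Y/N prompt used above.
    while True:
        content = input("Input(Y/N):").strip().lower()
        if content in ("y", "yes"):
            break
        if content in ("n", "no"):
            exit(1)
        logger.warning("Not regular input, please input again")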
     def execute(self, argv: List[str]) -> bool:
         """
         Command line entry.
         """
-        self.parser_args = self.parser.parse_args(argv)
-        model = self.parser_args.model
-        lang = self.parser_args.lang
-        config = self.parser_args.config
-        ckpt_path = self.parser_args.ckpt_path
-        audio_file = os.path.abspath(self.parser_args.input)
-        device = self.parser_args.device
+        parser_args = self.parser.parse_args(argv)
+        model = parser_args.model
+        lang = parser_args.lang
+        model_sample_rate = parser_args.model_sample_rate
+        config = parser_args.config
+        ckpt_path = parser_args.ckpt_path
+        audio_file = parser_args.input
+        device = parser_args.device
 
         try:
-            res = self(model, lang, config, ckpt_path, audio_file, device)
+            res = self(model, lang, model_sample_rate, config, ckpt_path,
+                       audio_file, device)
             logger.info('ASR Result: {}'.format(res))
             return True
         except Exception as e:
             print(e)
             return False
 
-    def __call__(self, model, lang, config, ckpt_path, audio_file, device):
+    def __call__(self, model, lang, model_sample_rate, config, ckpt_path,
+                 audio_file, device):
         """
         Python API to call an executor.
         """
-        self._init_from_path(model, lang, config, ckpt_path)
-        self.preprocess(audio_file)
-        self.infer()
+        audio_file = os.path.abspath(audio_file)
+        self._check(audio_file, model_sample_rate)
+        self._init_from_path(model, lang, model_sample_rate, config, ckpt_path,
+                             device)
+        self.preprocess(model, audio_file)
+        self.infer(model)
         res = self.postprocess()  # Retrieve result of asr.
         return res
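End to end, the Python API now threads model_sample_rate and device through every stage. A usage sketch matching the new __call__ signature shown above; the import path and file paths are assumptions:

    # Calling the executor directly from Python.
    from paddlespeech.cli.asr.infer import ASRExecutor

    asr = ASRExecutor()
    text = asr(
        'wenetspeech',  # model
        'zh',           # lang
        16000,          # model_sample_rate
        None,           # config: fall back to the bundled pretrained config
        None,           # ckpt_path
        './zh.wav',     # audio_file
        'cpu')          # device
    print(text)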
