Merge pull request #1074 from Jackwaterveg/fix_cli

[CLI] remove the os.chdir
Hui Zhang 3 years ago committed by GitHub
commit 664ee6167e
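This change removes the os.chdir(res_path) call from the ASR executor. Instead of switching the whole process into the downloaded model directory, every relative path taken from the pretrained-model config (vocab file, spm model, augmentation config, cmvn stats) is now resolved against res_path with os.path.join, so using the CLI no longer mutates the caller's working directory. A minimal sketch of the pattern, with hypothetical paths and keys that are not taken from this diff:

import os

# Assumed resource layout; the real res_path comes from _get_pretrained_path().
res_path = "/home/user/.paddlespeech/models/transformer_zh_16k"
relative_paths = {
    "vocab_filepath": "data/vocab.txt",     # hypothetical relative entries
    "cmvn_path": "data/mean_std.json",
}

# Old approach: os.chdir(res_path) made relative entries resolve, but changed
# global process state for every user of the library.
# New approach: make each entry absolute explicitly and leave the cwd alone.
absolute_paths = {k: os.path.join(res_path, v) for k, v in relative_paths.items()}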

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import io
 import os
 import sys
 from typing import List
@@ -19,10 +20,11 @@ from typing import Optional
 from typing import Union

 import librosa
+import numpy as np
 import paddle
 import soundfile
+import yaml
 from yacs.config import CfgNode
-import numpy as np

 from ..executor import BaseExecutor
 from ..utils import cli_register
@@ -46,6 +48,16 @@ pretrained_models = {
         'cfg_path':
         'conf/conformer.yaml',
         'ckpt_path':
         'exp/conformer/checkpoints/wenetspeech',
+    },
+    "transformer_zh_16k": {
+        'url':
+        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz',
+        'md5':
+        '4e8b63800c71040b9390b150e2a5d4c4',
+        'cfg_path':
+        'conf/transformer.yaml',
+        'ckpt_path':
+        'exp/transformer/checkpoints/avg_20',
     }
 }
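Note on the new entry: the executor builds the lookup tag as model_type + '_' + lang + '_' + sample_rate_str (see the hunks further down), so this model should be reachable with model='transformer', lang='zh', sample_rate=16000. A small sketch of the lookup, reusing the values added above (the dict key of the existing conformer entry is not shown in this hunk and is left out):

# Sketch of how a request maps onto the entry added in this hunk.
pretrained_models = {
    "transformer_zh_16k": {
        "url": "https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/transformer.model.tar.gz",
        "md5": "4e8b63800c71040b9390b150e2a5d4c4",
        "cfg_path": "conf/transformer.yaml",               # relative to res_path
        "ckpt_path": "exp/transformer/checkpoints/avg_20",
    },
}

model_type, lang, sample_rate = "transformer", "zh", 16000
sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "_" + lang + "_" + sample_rate_str   # -> "transformer_zh_16k"
entry = pretrained_models[tag]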
@@ -121,8 +133,7 @@ class ASRExecutor(BaseExecutor):
                        lang: str='zh',
                        sample_rate: int=16000,
                        cfg_path: Optional[os.PathLike]=None,
-                        ckpt_path: Optional[os.PathLike]=None
-                        ):
+                        ckpt_path: Optional[os.PathLike]=None):
        """
        Init model and other resources from a specific path.
        """
@@ -130,10 +141,11 @@ class ASRExecutor(BaseExecutor):
            sample_rate_str = '16k' if sample_rate == 16000 else '8k'
            tag = model_type + '_' + lang + '_' + sample_rate_str
            res_path = self._get_pretrained_path(tag)  # wenetspeech_zh
+            self.res_path = res_path
            self.cfg_path = os.path.join(res_path,
                                         pretrained_models[tag]['cfg_path'])
-            self.ckpt_path = os.path.join(res_path,
-                                          pretrained_models[tag]['ckpt_path'] + ".pdparams")
+            self.ckpt_path = os.path.join(
+                res_path, pretrained_models[tag]['ckpt_path'] + ".pdparams")
            logger.info(res_path)
            logger.info(self.cfg_path)
            logger.info(self.ckpt_path)
@@ -147,10 +159,8 @@ class ASRExecutor(BaseExecutor):
        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.cfg_path)
        self.config.decoding.decoding_method = "attention_rescoring"
-        model_conf = self.config.model
-        logger.info(model_conf)

-        with UpdateConfig(model_conf):
+        with UpdateConfig(self.config):
            if model_type == "ds2_online" or model_type == "ds2_offline":
                from paddlespeech.s2t.io.collator import SpeechCollator
                self.config.collator.vocab_filepath = os.path.join(
@@ -162,24 +172,29 @@ class ASRExecutor(BaseExecutor):
                    unit_type=self.config.collator.unit_type,
                    vocab_filepath=self.config.collator.vocab_filepath,
                    spm_model_prefix=self.config.collator.spm_model_prefix)
-                model_conf.input_dim = self.collate_fn_test.feature_size
-                model_conf.output_dim = text_feature.vocab_size
+                self.config.model.input_dim = self.collate_fn_test.feature_size
+                self.config.model.output_dim = text_feature.vocab_size
            elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
                self.config.collator.vocab_filepath = os.path.join(
                    res_path, self.config.collator.vocab_filepath)
+                self.config.collator.augmentation_config = os.path.join(
+                    res_path, self.config.collator.augmentation_config)
+                self.config.collator.spm_model_prefix = os.path.join(
+                    res_path, self.config.collator.spm_model_prefix)
                text_feature = TextFeaturizer(
                    unit_type=self.config.collator.unit_type,
                    vocab_filepath=self.config.collator.vocab_filepath,
                    spm_model_prefix=self.config.collator.spm_model_prefix)
-                model_conf.input_dim = self.config.collator.feat_dim
-                model_conf.output_dim = text_feature.vocab_size
+                self.config.model.input_dim = self.config.collator.feat_dim
+                self.config.model.output_dim = text_feature.vocab_size
            else:
                raise Exception("wrong type")
-        # Enter the path of model root
-        os.chdir(res_path)
+        self.config.freeze()
        model_class = dynamic_import(model_type, model_alias)
+        model_conf = self.config.model
+        logger.info(model_conf)
        model = model_class.from_config(model_conf)
        self.model = model
        self.model.eval()
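The model config is now edited through the full self.config node inside UpdateConfig and locked with self.config.freeze() before the model is built; model_conf = self.config.model is only read back where it is actually needed. A rough yacs-only sketch of that defrost/mutate/freeze pattern (UpdateConfig itself lives in paddlespeech.s2t and is assumed here to do the equivalent):

from yacs.config import CfgNode

config = CfgNode(new_allowed=True)
config.model = CfgNode(new_allowed=True)

config.defrost()                     # allow in-place edits
config.model.input_dim = 80          # e.g. feature dim taken from the collator
config.model.output_dim = 5000       # e.g. vocab size from the text featurizer
config.freeze()                      # lock the config before from_config(...)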
@@ -212,10 +227,17 @@ class ASRExecutor(BaseExecutor):
        elif model_type == "conformer" or model_type == "transformer" or model_type == "wenetspeech":
            logger.info("get the preprocess conf")
-            preprocess_conf = os.path.join(
-                os.path.dirname(os.path.abspath(self.cfg_path)),
-                "preprocess.yaml")
+            preprocess_conf_file = self.config.collator.augmentation_config
+            # redirect the cmvn path
+            with io.open(preprocess_conf_file, encoding="utf-8") as f:
+                preprocess_conf = yaml.safe_load(f)
+                for idx, process in enumerate(preprocess_conf["process"]):
+                    if process['type'] == "cmvn_json":
+                        preprocess_conf["process"][idx][
+                            "cmvn_path"] = os.path.join(
+                                self.res_path,
+                                preprocess_conf["process"][idx]["cmvn_path"])
+                        break
            logger.info(preprocess_conf)
            preprocess_args = {"train": False}
            preprocessing = Transformation(preprocess_conf)
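Because the process no longer chdirs into res_path, the cmvn_path stored inside the augmentation/preprocess YAML (a path relative to the model directory) has to be rewritten to an absolute one before Transformation reads it; that is what the loop above does. A standalone sketch of the same rewrite, with an assumed file location and assumed YAML content mirroring the keys the loop expects:

import io
import os
import yaml

res_path = "/home/user/.paddlespeech/models/transformer_zh_16k"         # assumed
preprocess_conf_file = os.path.join(res_path, "conf/preprocess.yaml")   # assumed name

# Assumed file content:
# process:
#   - type: cmvn_json
#     cmvn_path: data/mean_std.json
with io.open(preprocess_conf_file, encoding="utf-8") as f:
    preprocess_conf = yaml.safe_load(f)

for idx, process in enumerate(preprocess_conf["process"]):
    if process["type"] == "cmvn_json":
        # prepend the model root so the stats file is found from any cwd
        preprocess_conf["process"][idx]["cmvn_path"] = os.path.join(
            res_path, preprocess_conf["process"][idx]["cmvn_path"])
        break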
@@ -310,14 +332,14 @@ class ASRExecutor(BaseExecutor):
        return self._outputs["result"]

    def _pcm16to32(self, audio):
-        assert(audio.dtype == np.int16)
+        assert (audio.dtype == np.int16)
        audio = audio.astype("float32")
        bits = np.iinfo(np.int16).bits
        audio = audio / (2**(bits - 1))
        return audio

    def _pcm32to16(self, audio):
-        assert(audio.dtype == np.float32)
+        assert (audio.dtype == np.float32)
        bits = np.iinfo(np.int16).bits
        audio = audio * (2**(bits - 1))
        audio = np.round(audio).astype("int16")
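For context, _pcm16to32 and _pcm32to16 only rescale between int16 PCM and float32 in [-1, 1) by a factor of 2**15; the assert change above is purely cosmetic. A quick numpy round-trip check of the same arithmetic:

import numpy as np

pcm16 = np.array([0, 16384, -32768, 32767], dtype=np.int16)

audio = pcm16.astype("float32") / (2 ** 15)              # _pcm16to32 equivalent
restored = np.round(audio * (2 ** 15)).astype("int16")   # _pcm32to16 equivalent

assert np.array_equal(pcm16, restored)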
@@ -326,9 +348,7 @@ class ASRExecutor(BaseExecutor):
    def _check(self, audio_file: str, sample_rate: int):
        self.sample_rate = sample_rate
        if self.sample_rate != 16000 and self.sample_rate != 8000:
-            logger.error(
-                "please input --sr 8000 or --sr 16000"
-            )
+            logger.error("please input --sr 8000 or --sr 16000")
            raise Exception("invalid sample rate")
            sys.exit(-1)
@@ -354,13 +374,11 @@ class ASRExecutor(BaseExecutor):
                sys.exit(-1)
            logger.info("The sample rate is %d" % audio_sample_rate)
            if audio_sample_rate != self.sample_rate:
-                logger.warning(
-                    "The sample rate of the input file is not {}.\n \
+                logger.warning("The sample rate of the input file is not {}.\n \
                        The program will resample the wav file to {}.\n \
                        If the result does not meet your expectations\n \
                        Please input the 16k 16 bit 1 channel wav file. \
-                        "
-                        .format(self.sample_rate, self.sample_rate))
+                        ".format(self.sample_rate, self.sample_rate))
                while (True):
                    logger.info(
                        "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
@@ -398,16 +416,16 @@ class ASRExecutor(BaseExecutor):
        device = parser_args.device

        try:
-            res = self(model, lang, sample_rate, config, ckpt_path,
-                       audio_file, device)
+            res = self(model, lang, sample_rate, config, ckpt_path, audio_file,
+                       device)
            logger.info('ASR Result: {}'.format(res))
            return True
        except Exception as e:
            print(e)
            return False

-    def __call__(self, model, lang, sample_rate, config, ckpt_path,
-                 audio_file, device):
+    def __call__(self, model, lang, sample_rate, config, ckpt_path, audio_file,
+                 device):
        """
        Python API to call an executor.
        """
