Update ASR inference in paddlespeech.cli.

pull/1048/head
KP 4 years ago
parent d28888972f
commit e9798498d6

@ -14,7 +14,6 @@
import os import os
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import Optional
from typing import Union from typing import Union
import paddle import paddle
@ -30,16 +29,16 @@ class BaseExecutor(ABC):
self.output = None self.output = None
@abstractmethod @abstractmethod
def _get_default_cfg_path(self): def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
Returns a default config file path of current task. Download and returns pretrained resources path of current task.
""" """
pass pass
@abstractmethod @abstractmethod
def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): def _init_from_path(self, *args, **kwargs):
""" """
Init model from a specific config file. Init model and other resources from a specific path.
""" """
pass pass

@ -21,9 +21,21 @@ import paddle
from ..executor import BaseExecutor from ..executor import BaseExecutor
from ..utils import cli_register from ..utils import cli_register
from ..utils import download_and_decompress
from ..utils import logger
from ..utils import MODEL_HOME
__all__ = ['S2TExecutor'] __all__ = ['S2TExecutor']
# Registry of downloadable pretrained resources, keyed by a "<model>_<lang>" tag
# (e.g. "wenetspeech_zh"). Each entry provides the archive 'url' and the 'md5'
# value that is handed to the download helper together with it.
pretrained_models = {
"wenetspeech_zh": {
'url':
'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/conformer.model.tar.gz',
'md5':
'54e7a558a6e020c2f5fb224874943f97',
}
}
@cli_register( @cli_register(
name='paddlespeech.s2t', description='Speech to text infer command.') name='paddlespeech.s2t', description='Speech to text infer command.')
@ -33,11 +45,23 @@ class S2TExecutor(BaseExecutor):
self.parser = argparse.ArgumentParser( self.parser = argparse.ArgumentParser(
prog='paddlespeech.s2t', add_help=True) prog='paddlespeech.s2t', add_help=True)
self.parser.add_argument(
'--model',
type=str,
default='wenetspeech',
help='Choose model type of asr task.')
self.parser.add_argument(
'--lang', type=str, default='zh', help='Choose model language.')
self.parser.add_argument( self.parser.add_argument(
'--config', '--config',
type=str, type=str,
default=None, default=None,
help='Config of s2t task. Use deault config when it is None.') help='Config of s2t task. Use deault config when it is None.')
self.parser.add_argument(
'--ckpt_path',
type=str,
default=None,
help='Checkpoint file of model.')
self.parser.add_argument( self.parser.add_argument(
'--input', type=str, help='Audio file to recognize.') '--input', type=str, help='Audio file to recognize.')
self.parser.add_argument( self.parser.add_argument(
@ -46,16 +70,39 @@ class S2TExecutor(BaseExecutor):
default='cpu', default='cpu',
help='Choose device to execute model inference.') help='Choose device to execute model inference.')
def _get_default_cfg_path(self): def _get_pretrained_path(self, tag: str) -> os.PathLike:
""" """
Returns a default config file path of current task. Download and returns pretrained resources path of current task.
""" """
pass assert tag in pretrained_models, 'Can not find pretrained resources of {}.'.format(
tag)
res_path = os.path.join(MODEL_HOME, tag)
decompressed_path = download_and_decompress(pretrained_models[tag],
res_path)
logger.info(
'Use pretrained model stored in: {}'.format(decompressed_path))
return decompressed_path
def _init_from_cfg(self, cfg_path: Optional[os.PathLike]=None): def _init_from_path(self,
model_type: str='wenetspeech',
lang: str='zh',
cfg_path: Optional[os.PathLike]=None,
ckpt_path: Optional[os.PathLike]=None):
""" """
Init model from a specific config file. Init model and other resources from a specific path.
""" """
if cfg_path is None or ckpt_path is None:
res_path = self._get_pretrained_path(
model_type + '_' + lang) # wenetspeech_zh
cfg_path = os.path.join(res_path, 'conf/conformer.yaml')
ckpt_path = os.path.join(
res_path, 'exp/conformer/checkpoints/wenetspeech.pdparams')
logger.info(res_path)
logger.info(cfg_path)
logger.info(ckpt_path)
# Init body.
pass pass
def preprocess(self, input: Union[str, os.PathLike]): def preprocess(self, input: Union[str, os.PathLike]):
@ -82,17 +129,15 @@ class S2TExecutor(BaseExecutor):
parser_args = self.parser.parse_args(argv) parser_args = self.parser.parse_args(argv)
print(parser_args) print(parser_args)
model = parser_args.model
lang = parser_args.lang
config = parser_args.config config = parser_args.config
ckpt_path = parser_args.ckpt_path
audio_file = parser_args.input audio_file = parser_args.input
device = parser_args.device device = parser_args.device
if config is not None:
assert os.path.isfile(config), 'Config file is not valid.'
else:
config = self._get_default_cfg_path()
try: try:
self._init_from_cfg(config) self._init_from_path(model, lang, config, ckpt_path)
self.preprocess(audio_file) self.preprocess(audio_file)
self.infer() self.infer()
res = self.postprocess() # Retrieve result of s2t. res = self.postprocess() # Retrieve result of s2t.

@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import logging
import os import os
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import List
from paddle.framework import load from paddle.framework import load
from paddle.utils import download from paddle.utils import download
@ -26,6 +27,7 @@ __all__ = [
'get_command', 'get_command',
'download_and_decompress', 'download_and_decompress',
'load_state_dict_from_url', 'load_state_dict_from_url',
'logger',
] ]
@ -53,29 +55,27 @@ def get_command(name: str) -> Any:
return com['_entry'] return com['_entry']
def decompress(file: str): def decompress(file: str) -> os.PathLike:
""" """
Extracts all files from a compressed file. Extracts all files from a compressed file.
""" """
assert os.path.isfile(file), "File: {} not exists.".format(file) assert os.path.isfile(file), "File: {} not exists.".format(file)
download._decompress(file) return download._decompress(file)
def download_and_decompress(archives: List[Dict[str, str]], path: str): def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
""" """
Download archieves and decompress to specific path. Download archieves and decompress to specific path.
""" """
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path) os.makedirs(path)
for archive in archives:
assert 'url' in archive and 'md5' in archive, \ assert 'url' in archive and 'md5' in archive, \
'Dictionary keys of "url" and "md5" are required in the archive, but got: {list(archieve.keys())}' 'Dictionary keys of "url" and "md5" are required in the archive, but got: {}'.format(list(archive.keys()))
return download.get_path_from_url(archive['url'], path, archive['md5'])
download.get_path_from_url(archive['url'], path, archive['md5'])
def load_state_dict_from_url(url: str, path: str, md5: str=None) -> os.PathLike:
def load_state_dict_from_url(url: str, path: str, md5: str=None):
""" """
Download and load a state dict from url Download and load a state dict from url
""" """
@ -84,3 +84,69 @@ def load_state_dict_from_url(url: str, path: str, md5: str=None):
download.get_path_from_url(url, path, md5) download.get_path_from_url(url, path, md5)
return load(os.path.join(path, os.path.basename(url))) return load(os.path.join(path, os.path.basename(url)))
def _get_user_home():
    """Return the current user's home directory (the expansion of ``~``)."""
    user_home = os.path.expanduser('~')
    return user_home
def _get_paddlespcceh_home():
    """Resolve the root directory used for PaddleSpeech resources.

    The ``PPSPEECH_HOME`` environment variable wins when it is set: its value
    is returned as-is if it names a directory or does not exist yet.
    Otherwise the ``.paddlespeech`` folder inside the user's home directory
    is used.

    Raises:
        RuntimeError: ``PPSPEECH_HOME`` points at something that exists but
            is not a directory.
    """
    env_home = os.environ.get('PPSPEECH_HOME')

    # Variable not set at all: fall back to the per-user default location.
    if env_home is None:
        return os.path.join(_get_user_home(), '.paddlespeech')

    # An existing non-directory (e.g. a regular file) cannot serve as home.
    if os.path.exists(env_home) and not os.path.isdir(env_home):
        raise RuntimeError(
            'The environment variable PPSPEECH_HOME {} is not a directory.'.
            format(env_home))

    return env_home
def _get_sub_home(directory):
    """Return ``<paddlespeech home>/<directory>``, creating it when missing.

    Args:
        directory: Name of the sub-directory below the PaddleSpeech home.

    Returns:
        The joined path of the sub-directory.
    """
    home = os.path.join(_get_paddlespcceh_home(), directory)
    if not os.path.exists(home):
        # This runs at import time, so several processes may reach this point
        # together. exist_ok=True keeps the loser of that race from crashing
        # with FileExistsError after another process created the directory.
        os.makedirs(home, exist_ok=True)
    return home
# Root directory for PaddleSpeech resources, resolved once when this module is imported.
PPSPEECH_HOME = _get_paddlespcceh_home()
# 'models' sub-directory of the PaddleSpeech home; created at import time if it does not exist.
MODEL_HOME = _get_sub_home('models')
class Logger(object):
    """Wrapper around a named ``logging.Logger`` with per-level helper methods.

    Besides the standard levels, two custom ones are registered: ``TRAIN``
    (21) and ``EVAL`` (22). For every configured level a bound helper is
    attached to the instance under the lower-cased level name, so callers can
    write ``logger.info(msg)``, ``logger.train(msg)``, ``logger.eval(msg)``
    and so on.

    Args:
        name: Name of the underlying logger. Defaults to ``'PaddleSpeech'``
            when empty or ``None``.

    NOTE: ``%(filename)s`` / ``%(lineno)d`` in the format are resolved from
    the frame that calls ``self.logger.log``, i.e. this wrapper rather than
    the user code. Passing ``stacklevel=2`` would report the real caller, but
    that argument requires Python 3.8+, so it is not applied here.
    """

    def __init__(self, name: str=None):
        name = 'PaddleSpeech' if not name else name
        self.logger = logging.getLogger(name)

        log_config = {
            'DEBUG': 10,
            'INFO': 20,
            'TRAIN': 21,
            'EVAL': 22,
            'WARNING': 30,
            'ERROR': 40,
            'CRITICAL': 50
        }
        for key, level in log_config.items():
            logging.addLevelName(level, key)
            # Expose e.g. self.info / self.train as __call__ with a fixed level.
            self.__dict__[key.lower()] = functools.partial(self.__call__, level)

        self.format = logging.Formatter(
            fmt='[%(asctime)-15s] [%(levelname)8s] [%(filename)s] [L%(lineno)d] - %(message)s'
        )

        # logging.getLogger hands back the same object for the same name, so
        # creating a second Logger(name) used to stack another StreamHandler
        # and every record was emitted twice. Reuse the handler if one is
        # already attached. The exact type check is deliberate: subclasses
        # such as FileHandler must not be mistaken for the console handler.
        self.handler = next(
            (h for h in self.logger.handlers
             if type(h) is logging.StreamHandler), None)
        if self.handler is None:
            self.handler = logging.StreamHandler()
            self.logger.addHandler(self.handler)
        self.handler.setFormatter(self.format)

        self.logger.setLevel(logging.DEBUG)
        # Keep records out of the root logger so they are not printed again.
        self.logger.propagate = False

    def __call__(self, log_level: int, msg: str):
        """Emit ``msg`` on the wrapped logger at the numeric ``log_level``."""
        self.logger.log(log_level, msg)
# Module-level Logger instance created at import time with the default name.
logger = Logger()

Loading…
Cancel
Save