diff --git a/demos/speech_recognition/run.sh b/demos/speech_recognition/run.sh index 5eacce9d9..8ba6e4c3e 100755 --- a/demos/speech_recognition/run.sh +++ b/demos/speech_recognition/run.sh @@ -1,5 +1,4 @@ #!/bin/bash -#!/bin/bash wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 3f8e5f65e..7a7aef8b0 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -62,7 +62,7 @@ class ASRExecutor(BaseExecutor): '--lang', type=str, default='zh', - help='Choose model language. [zh, en, zh_en], zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k], zh_en:[conformer_talcs-zh_en-16k-codeswitch]' + help='Choose model language. [zh, en, zh_en], zh:[conformer_wenetspeech-zh-16k], en:[transformer_librispeech-en-16k], zh_en:[conformer_talcs-codeswitch_zh_en-16k]' ) self.parser.add_argument( '--codeswitch', @@ -151,7 +151,9 @@ class ASRExecutor(BaseExecutor): if cfg_path is None or ckpt_path is None: sample_rate_str = '16k' if sample_rate == 16000 else '8k' if lang == "zh_en" and codeswitch is True: - tag = model_type + '-' + lang + '-' + sample_rate_str + '-' + 'codeswitch' + tag = model_type + '-' + 'codeswitch_' + lang + '-' + sample_rate_str + elif lang == "zh_en" or codeswitch is True: + raise Exception("codeswitch is true only in zh_en model") else: tag = model_type + '-' + lang + '-' + sample_rate_str self.task_resource.set_task_model(tag, version=None) diff --git a/paddlespeech/cli/base_commands.py b/paddlespeech/cli/base_commands.py index efcd671d0..dfeb5cae5 100644 --- a/paddlespeech/cli/base_commands.py +++ b/paddlespeech/cli/base_commands.py @@ -116,6 +116,12 @@ class StatsCommand: if self.task == "asr" and len(line) < len(fields): for i in range(len(line), len(fields)): line.append("-") + if "codeswitch" in key: + line[3], line[1] = line[1].split("_")[0], line[1].split( + "_")[1:] + elif "multilingual" in key: + line[4], line[1] = line[1].split("_")[0], line[1].split( + "_")[1:] tmp = numpy.array(line) idx = [0, 5, 3, 4, 1, 2] line = tmp[idx] diff --git a/paddlespeech/resource/pretrained_models.py b/paddlespeech/resource/pretrained_models.py index 5da97692f..ff0b30f6d 100644 --- a/paddlespeech/resource/pretrained_models.py +++ b/paddlespeech/resource/pretrained_models.py @@ -30,6 +30,7 @@ __all__ = [ ] # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]". +# Add code-switch and multilingual tag, "{model_name}[_{dataset}]-[codeswitch/multilingual][_{lang}][-...]". # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k". # Command line and python api use "{model_name}[_{dataset}]" as --model, usage: # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav" @@ -322,7 +323,7 @@ asr_dynamic_pretrained_models = { '099a601759d467cd0a8523ff939819c5' }, }, - "conformer_talcs-zh_en-16k-codeswitch": { + "conformer_talcs-codeswitch_zh_en-16k": { '1.4': { 'url': 'https://paddlespeech.bj.bcebos.com/s2t/tal_cs/asr1/asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz', diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index 1b1792bd1..299a8c3d4 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -16,14 +16,9 @@ import sys import warnings from typing import List +import numpy import uvicorn from fastapi import FastAPI -from prettytable import PrettyTable -from starlette.middleware.cors import CORSMiddleware - -from ..executor import BaseExecutor -from ..util import cli_server_register -from ..util import stats_wrapper from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource from paddlespeech.server.engine.engine_pool import init_engine_pool @@ -31,6 +26,12 @@ from paddlespeech.server.engine.engine_warmup import warm_up from paddlespeech.server.restful.api import setup_router as setup_http_router from paddlespeech.server.utils.config import get_config from paddlespeech.server.ws.api import setup_router as setup_ws_router +from prettytable import PrettyTable +from starlette.middleware.cors import CORSMiddleware + +from ..executor import BaseExecutor +from ..util import cli_server_register +from ..util import stats_wrapper warnings.filterwarnings("ignore") __all__ = ['ServerExecutor', 'ServerStatsExecutor'] @@ -134,7 +135,7 @@ class ServerStatsExecutor(): required=True) self.task_choices = ['asr', 'tts', 'cls', 'text', 'vector'] self.model_name_format = { - 'asr': 'Model-Language-Sample Rate', + 'asr': 'Model-Size-Code Switch-Multilingual-Language-Sample Rate', 'tts': 'Model-Language', 'cls': 'Model-Sample Rate', 'text': 'Model-Task-Language', @@ -145,7 +146,20 @@ class ServerStatsExecutor(): fields = self.model_name_format[self.task].split("-") table = PrettyTable(fields) for key in pretrained_models: - table.add_row(key.split("-")) + line = key.split("-") + if self.task == "asr" and len(line) < len(fields): + for i in range(len(line), len(fields)): + line.append("-") + if "codeswitch" in key: + line[3], line[1] = line[1].split("_")[0], line[1].split( + "_")[1:] + elif "multilingual" in key: + line[4], line[1] = line[1].split("_")[0], line[1].split( + "_")[1:] + tmp = numpy.array(line) + idx = [0, 5, 3, 4, 1, 2] + line = tmp[idx] + table.add_row(line) print(table) def execute(self, argv: List[str]) -> bool: diff --git a/tests/unit/cli/test_cli.sh b/tests/unit/cli/test_cli.sh index 3a58626d2..5d3b76f6c 100755 --- a/tests/unit/cli/test_cli.sh +++ b/tests/unit/cli/test_cli.sh @@ -14,7 +14,7 @@ paddlespeech ssl --task asr --lang en --input ./en.wav paddlespeech ssl --task vector --lang en --input ./en.wav # Speech_recognition -wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav +wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/zh.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav https://paddlespeech.bj.bcebos.com/PaddleAudio/ch_zh_mix.wav paddlespeech asr --input ./zh.wav paddlespeech asr --model conformer_aishell --input ./zh.wav paddlespeech asr --model conformer_online_aishell --input ./zh.wav @@ -26,6 +26,7 @@ paddlespeech asr --model deepspeech2offline_aishell --input ./zh.wav paddlespeech asr --model deepspeech2online_wenetspeech --input ./zh.wav paddlespeech asr --model deepspeech2online_aishell --input ./zh.wav paddlespeech asr --model deepspeech2offline_librispeech --lang en --input ./en.wav +paddlespeech asr --model conformer_talcs --lang zh_en --codeswitch True --input ./ch_zh_mix.wav # Support editing num_decoding_left_chunks paddlespeech asr --model conformer_online_wenetspeech --num_decoding_left_chunks 3 --input ./zh.wav