|
|
@ -16,14 +16,9 @@ import sys
|
|
|
|
import warnings
|
|
|
|
import warnings
|
|
|
|
from typing import List
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy
|
|
|
|
import uvicorn
|
|
|
|
import uvicorn
|
|
|
|
from fastapi import FastAPI
|
|
|
|
from fastapi import FastAPI
|
|
|
|
from prettytable import PrettyTable
|
|
|
|
|
|
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
|
|
|
|
from ..util import cli_server_register
|
|
|
|
|
|
|
|
from ..util import stats_wrapper
|
|
|
|
|
|
|
|
from paddlespeech.cli.log import logger
|
|
|
|
from paddlespeech.cli.log import logger
|
|
|
|
from paddlespeech.resource import CommonTaskResource
|
|
|
|
from paddlespeech.resource import CommonTaskResource
|
|
|
|
from paddlespeech.server.engine.engine_pool import init_engine_pool
|
|
|
|
from paddlespeech.server.engine.engine_pool import init_engine_pool
|
|
|
@ -31,6 +26,12 @@ from paddlespeech.server.engine.engine_warmup import warm_up
|
|
|
|
from paddlespeech.server.restful.api import setup_router as setup_http_router
|
|
|
|
from paddlespeech.server.restful.api import setup_router as setup_http_router
|
|
|
|
from paddlespeech.server.utils.config import get_config
|
|
|
|
from paddlespeech.server.utils.config import get_config
|
|
|
|
from paddlespeech.server.ws.api import setup_router as setup_ws_router
|
|
|
|
from paddlespeech.server.ws.api import setup_router as setup_ws_router
|
|
|
|
|
|
|
|
from prettytable import PrettyTable
|
|
|
|
|
|
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from ..executor import BaseExecutor
|
|
|
|
|
|
|
|
from ..util import cli_server_register
|
|
|
|
|
|
|
|
from ..util import stats_wrapper
|
|
|
|
warnings.filterwarnings("ignore")
|
|
|
|
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ServerExecutor', 'ServerStatsExecutor']
|
|
|
|
__all__ = ['ServerExecutor', 'ServerStatsExecutor']
|
|
|
@ -134,7 +135,7 @@ class ServerStatsExecutor():
|
|
|
|
required=True)
|
|
|
|
required=True)
|
|
|
|
self.task_choices = ['asr', 'tts', 'cls', 'text', 'vector']
|
|
|
|
self.task_choices = ['asr', 'tts', 'cls', 'text', 'vector']
|
|
|
|
self.model_name_format = {
|
|
|
|
self.model_name_format = {
|
|
|
|
'asr': 'Model-Language-Sample Rate',
|
|
|
|
'asr': 'Model-Size-Code Switch-Multilingual-Language-Sample Rate',
|
|
|
|
'tts': 'Model-Language',
|
|
|
|
'tts': 'Model-Language',
|
|
|
|
'cls': 'Model-Sample Rate',
|
|
|
|
'cls': 'Model-Sample Rate',
|
|
|
|
'text': 'Model-Task-Language',
|
|
|
|
'text': 'Model-Task-Language',
|
|
|
@ -145,7 +146,20 @@ class ServerStatsExecutor():
|
|
|
|
fields = self.model_name_format[self.task].split("-")
|
|
|
|
fields = self.model_name_format[self.task].split("-")
|
|
|
|
table = PrettyTable(fields)
|
|
|
|
table = PrettyTable(fields)
|
|
|
|
for key in pretrained_models:
|
|
|
|
for key in pretrained_models:
|
|
|
|
table.add_row(key.split("-"))
|
|
|
|
line = key.split("-")
|
|
|
|
|
|
|
|
if self.task == "asr" and len(line) < len(fields):
|
|
|
|
|
|
|
|
for i in range(len(line), len(fields)):
|
|
|
|
|
|
|
|
line.append("-")
|
|
|
|
|
|
|
|
if "codeswitch" in key:
|
|
|
|
|
|
|
|
line[3], line[1] = line[1].split("_")[0], line[1].split(
|
|
|
|
|
|
|
|
"_")[1:]
|
|
|
|
|
|
|
|
elif "multilingual" in key:
|
|
|
|
|
|
|
|
line[4], line[1] = line[1].split("_")[0], line[1].split(
|
|
|
|
|
|
|
|
"_")[1:]
|
|
|
|
|
|
|
|
tmp = numpy.array(line)
|
|
|
|
|
|
|
|
idx = [0, 5, 3, 4, 1, 2]
|
|
|
|
|
|
|
|
line = tmp[idx]
|
|
|
|
|
|
|
|
table.add_row(line)
|
|
|
|
print(table)
|
|
|
|
print(table)
|
|
|
|
|
|
|
|
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|
def execute(self, argv: List[str]) -> bool:
|
|
|
|