Merge pull request #1874 from zh794390558/server

[server] unify name style & frame with abs timestamp
pull/1878/head
Hui Zhang 2 years ago committed by GitHub
commit 14ad165c7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,77 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
app = FastAPI(
title="PaddleSpeech Serving API", description="Api", version="0.0.1")
def init(config):
"""system initialization
Args:
config (CfgNode): config object
Returns:
bool:
"""
# init api
api_list = list(engine.split("_")[0] for engine in config.engine_list)
if config.protocol == "websocket":
api_router = setup_ws_router(api_list)
elif config.protocol == "http":
api_router = setup_http_router(api_list)
else:
raise Exception("unsupported protocol")
app.include_router(api_router)
if not init_engine_pool(config):
return False
return True
def main(args):
"""main function"""
config = get_config(args.config_file)
if init(config):
uvicorn.run(app, host=config.host, port=config.port, debug=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_file",
action="store",
help="yaml file of the app",
default="./conf/application.yaml")
parser.add_argument(
"--log_file",
action="store",
help="log file",
default="./log/paddlespeech.log")
args = parser.parse_args()
main(args)

@ -29,19 +29,19 @@ def setup_router(api_list: List):
"""setup router for fastapi """setup router for fastapi
Args: Args:
api_list (List): [asr, tts, cls] api_list (List): [asr, tts, cls, text, vecotr]
Returns: Returns:
APIRouter APIRouter
""" """
for api_name in api_list: for api_name in api_list:
if api_name == 'asr': if api_name.lower() == 'asr':
_router.include_router(asr_router) _router.include_router(asr_router)
elif api_name == 'tts': elif api_name.lower() == 'tts':
_router.include_router(tts_router) _router.include_router(tts_router)
elif api_name == 'cls': elif api_name.lower() == 'cls':
_router.include_router(cls_router) _router.include_router(cls_router)
elif api_name == 'text': elif api_name.lower() == 'text':
_router.include_router(text_router) _router.include_router(text_router)
elif api_name.lower() == 'vector': elif api_name.lower() == 'vector':
_router.include_router(vec_router) _router.include_router(vec_router)

@ -43,6 +43,7 @@ class TextHttpHandler:
else: else:
self.url = 'http://' + self.server_ip + ":" + str( self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/text' self.port) + '/paddlespeech/text'
logger.info(f"endpoint: {self.url}")
def run(self, text): def run(self, text):
"""Call the text server to process the specific text """Call the text server to process the specific text
@ -107,8 +108,10 @@ class ASRWsAudioHandler:
""" """
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
assert sample_rate == 16000
chunk_size = int(85 * sample_rate / 1000) # 85ms, sample_rate = 16kHz
chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0: if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size padding_len_x = chunk_size - x_len % chunk_size
else: else:
@ -217,6 +220,7 @@ class ASRHttpHandler:
else: else:
self.url = 'http://' + self.server_ip + ":" + str( self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/asr' self.port) + '/paddlespeech/asr'
logger.info(f"endpoint: {self.url}")
def run(self, input, audio_format, sample_rate, lang): def run(self, input, audio_format, sample_rate, lang):
"""Call the http asr to process the audio """Call the http asr to process the audio
@ -275,6 +279,7 @@ class TTSWsHandler:
self.start_play = True self.start_play = True
self.t = threading.Thread(target=self.play_audio) self.t = threading.Thread(target=self.play_audio)
self.max_fail = 50 self.max_fail = 50
logger.info(f"endpoint: {self.url}")
def play_audio(self): def play_audio(self):
while True: while True:
@ -383,6 +388,7 @@ class TTSHttpHandler:
self.start_play = True self.start_play = True
self.t = threading.Thread(target=self.play_audio) self.t = threading.Thread(target=self.play_audio)
self.max_fail = 50 self.max_fail = 50
logger.info(f"endpoint: {self.url}")
def play_audio(self): def play_audio(self):
while True: while True:
@ -483,6 +489,7 @@ class VectorHttpHandler:
else: else:
self.url = 'http://' + self.server_ip + ":" + str( self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector' self.port) + '/paddlespeech/vector'
logger.info(f"endpoint: {self.url}")
def run(self, input, audio_format, sample_rate, task="spk"): def run(self, input, audio_format, sample_rate, task="spk"):
"""Call the http asr to process the audio """Call the http asr to process the audio
@ -529,6 +536,7 @@ class VectorScoreHttpHandler:
else: else:
self.url = 'http://' + self.server_ip + ":" + str( self.url = 'http://' + self.server_ip + ":" + str(
self.port) + '/paddlespeech/vector/score' self.port) + '/paddlespeech/vector/score'
logger.info(f"endpoint: {self.url}")
def run(self, enroll_audio, test_audio, audio_format, sample_rate): def run(self, enroll_audio, test_audio, audio_format, sample_rate):
"""Call the http asr to process the audio """Call the http asr to process the audio

@ -107,7 +107,7 @@ def change_speed(sample_raw, speed_rate, sample_rate):
def float2pcm(sig, dtype='int16'): def float2pcm(sig, dtype='int16'):
"""Convert floating point signal with a range from -1 to 1 to PCM. """Convert floating point signal with a range from -1 to 1 to PCM16.
Args: Args:
sig (array): Input array, must have floating point type. sig (array): Input array, must have floating point type.

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class Frame(object): class Frame(object):
"""Represents a "frame" of audio data.""" """Represents a "frame" of audio data."""
@ -46,8 +45,7 @@ class ChunkBuffer(object):
self.shift_ms = shift_ms self.shift_ms = shift_ms
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.sample_width = sample_width # int16 = 2; float32 = 4 self.sample_width = sample_width # int16 = 2; float32 = 4
self.remained_audio = b''
self.window_sec = float((self.window_n - 1) * self.shift_ms + self.window_sec = float((self.window_n - 1) * self.shift_ms +
self.window_ms) / 1000.0 self.window_ms) / 1000.0
self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0) self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
@ -57,22 +55,31 @@ class ChunkBuffer(object):
self.shift_bytes = int(self.shift_sec * self.sample_rate * self.shift_bytes = int(self.shift_sec * self.sample_rate *
self.sample_width) self.sample_width)
self.remained_audio = b''
# abs timestamp from `start` or latest `reset`
self.timestamp = 0.0
def reset(self):
"""
reset buffer state.
"""
self.timestamp = 0.0
self.remained_audio = b''
def frame_generator(self, audio): def frame_generator(self, audio):
"""Generates audio frames from PCM audio data. """Generates audio frames from PCM audio data.
Takes the desired frame duration in milliseconds, the PCM data, and Takes the desired frame duration in milliseconds, the PCM data, and
the sample rate. the sample rate.
Yields Frames of the requested duration. Yields Frames of the requested duration.
""" """
audio = self.remained_audio + audio audio = self.remained_audio + audio
self.remained_audio = b'' self.remained_audio = b''
offset = 0 offset = 0
timestamp = 0.0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], timestamp, yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
self.window_sec) self.window_sec)
timestamp += self.shift_sec self.timestamp += self.shift_sec
offset += self.shift_bytes offset += self.shift_bytes
self.remained_audio += audio[offset:] self.remained_audio += audio[offset:]

@ -15,8 +15,8 @@ from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.server.ws.asr_socket import router as asr_router from paddlespeech.server.ws.asr_api import router as asr_router
from paddlespeech.server.ws.tts_socket import router as tts_router from paddlespeech.server.ws.tts_api import router as tts_router
_router = APIRouter() _router = APIRouter()

@ -58,4 +58,4 @@ async def websocket_endpoint(websocket: WebSocket):
break break
except WebSocketDisconnect: except WebSocketDisconnect:
pass pass
Loading…
Cancel
Save