unify name style & frame with abs timestamp

pull/1874/head
Hui Zhang 3 years ago
parent 15b25199c2
commit 7be6b0e8cf

@@ -1,77 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import uvicorn
from fastapi import FastAPI
from paddlespeech.server.engine.engine_pool import init_engine_pool
from paddlespeech.server.restful.api import setup_router as setup_http_router
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.ws.api import setup_router as setup_ws_router
app = FastAPI(
    title="PaddleSpeech Serving API", description="Api", version="0.0.1")


def init(config):
    """system initialization

    Args:
        config (CfgNode): config object

    Returns:
        bool: True if the api routers and the engine pool are initialized successfully
    """
    # init api
    api_list = list(engine.split("_")[0] for engine in config.engine_list)
    if config.protocol == "websocket":
        api_router = setup_ws_router(api_list)
    elif config.protocol == "http":
        api_router = setup_http_router(api_list)
    else:
        raise Exception("unsupported protocol")
    app.include_router(api_router)

    if not init_engine_pool(config):
        return False

    return True


def main(args):
    """main function"""
    config = get_config(args.config_file)

    if init(config):
        uvicorn.run(app, host=config.host, port=config.port, debug=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_file",
        action="store",
        help="yaml file of the app",
        default="./conf/application.yaml")
    parser.add_argument(
        "--log_file",
        action="store",
        help="log file",
        default="./log/paddlespeech.log")

    args = parser.parse_args()
    main(args)

@@ -29,19 +29,19 @@ def setup_router(api_list: List):
     """setup router for fastapi

     Args:
-        api_list (List): [asr, tts, cls]
+        api_list (List): [asr, tts, cls, text, vector]

     Returns:
         APIRouter
     """
     for api_name in api_list:
-        if api_name == 'asr':
+        if api_name.lower() == 'asr':
             _router.include_router(asr_router)
-        elif api_name == 'tts':
+        elif api_name.lower() == 'tts':
             _router.include_router(tts_router)
-        elif api_name == 'cls':
+        elif api_name.lower() == 'cls':
             _router.include_router(cls_router)
-        elif api_name == 'text':
+        elif api_name.lower() == 'text':
             _router.include_router(text_router)
         elif api_name.lower() == 'vector':
             _router.include_router(vec_router)
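
For context, a minimal sketch of how the api names reach `setup_router()`: `main.py` derives them from `config.engine_list` with `engine.split("_")[0]`, and the hunk above makes the comparison case-insensitive. The `engine_list` values below are hypothetical examples, not taken from a real config:

```python
# Hypothetical engine_list entries; real configs define them in application.yaml.
engine_list = ["ASR_online", "tts_python", "cls_inference"]

# main.py keeps only the task name in front of the first underscore
api_list = [engine.split("_")[0] for engine in engine_list]
print(api_list)                             # ['ASR', 'tts', 'cls']

# setup_router() now lower-cases before comparing, so 'ASR' still routes to asr_router
print([name.lower() for name in api_list])  # ['asr', 'tts', 'cls']
```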

@@ -43,6 +43,7 @@ class TextHttpHandler:
         else:
             self.url = 'http://' + self.server_ip + ":" + str(
                 self.port) + '/paddlespeech/text'
+        logger.info(f"endpoint: {self.url}")
 
     def run(self, text):
         """Call the text server to process the specific text
@@ -107,8 +108,10 @@ class ASRWsAudioHandler:
         """
         samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
         x_len = len(samples)
-        chunk_size = 85 * 16  #80ms, sample_rate = 16kHz
+        assert sample_rate == 16000
+        chunk_size = int(85 * sample_rate / 1000)  # 85ms, sample_rate = 16kHz
 
         if x_len % chunk_size != 0:
             padding_len_x = chunk_size - x_len % chunk_size
         else:
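
The removed constant and the new expression give the same number of samples at 16 kHz; the new form just states the 85 ms window explicitly and guards the sample rate with an assert. A quick check of the arithmetic (standalone sketch, not repo code):

```python
sample_rate = 16000                       # enforced by the new assert
chunk_ms = 85                             # streaming chunk duration

chunk_size = int(chunk_ms * sample_rate / 1000)
print(chunk_size)                         # 1360 samples, same as the old hard-coded 85 * 16
print(chunk_size / sample_rate * 1000)    # 85.0 ms per chunk
```
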
@@ -217,6 +220,7 @@ class ASRHttpHandler:
         else:
             self.url = 'http://' + self.server_ip + ":" + str(
                 self.port) + '/paddlespeech/asr'
+        logger.info(f"endpoint: {self.url}")
 
     def run(self, input, audio_format, sample_rate, lang):
         """Call the http asr to process the audio
@@ -275,6 +279,7 @@ class TTSWsHandler:
         self.start_play = True
         self.t = threading.Thread(target=self.play_audio)
         self.max_fail = 50
+        logger.info(f"endpoint: {self.url}")
 
     def play_audio(self):
         while True:
@@ -383,6 +388,7 @@ class TTSHttpHandler:
         self.start_play = True
         self.t = threading.Thread(target=self.play_audio)
         self.max_fail = 50
+        logger.info(f"endpoint: {self.url}")
 
     def play_audio(self):
         while True:
@@ -483,6 +489,7 @@ class VectorHttpHandler:
         else:
             self.url = 'http://' + self.server_ip + ":" + str(
                 self.port) + '/paddlespeech/vector'
+        logger.info(f"endpoint: {self.url}")
 
     def run(self, input, audio_format, sample_rate, task="spk"):
         """Call the http asr to process the audio
@@ -529,6 +536,7 @@ class VectorScoreHttpHandler:
         else:
             self.url = 'http://' + self.server_ip + ":" + str(
                 self.port) + '/paddlespeech/vector/score'
+        logger.info(f"endpoint: {self.url}")
 
     def run(self, enroll_audio, test_audio, audio_format, sample_rate):
         """Call the http asr to process the audio

@@ -107,7 +107,7 @@ def change_speed(sample_raw, speed_rate, sample_rate):
 def float2pcm(sig, dtype='int16'):
-    """Convert floating point signal with a range from -1 to 1 to PCM.
+    """Convert floating point signal with a range from -1 to 1 to PCM16.
     Args:
         sig (array): Input array, must have floating point type.
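
The docstring now names the target format PCM16 explicitly. A minimal sketch of such a float-to-int16 conversion, assuming the usual scale-and-clip approach (the helper name below is illustrative, not the repo's exact implementation):

```python
import numpy as np


def float_to_pcm16(sig):
    """Scale a float signal in [-1.0, 1.0] to int16 PCM, clipping any overshoot."""
    sig = np.asarray(sig, dtype=np.float64)
    return np.clip(sig * 32767, -32768, 32767).astype(np.int16)


print(float_to_pcm16([0.0, 0.5, -1.0, 1.2]))  # [0 16383 -32767 32767]
```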

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 class Frame(object):
     """Represents a "frame" of audio data."""
@@ -46,8 +45,7 @@ class ChunkBuffer(object):
         self.shift_ms = shift_ms
         self.sample_rate = sample_rate
         self.sample_width = sample_width  # int16 = 2; float32 = 4
-        self.remained_audio = b''
         self.window_sec = float((self.window_n - 1) * self.shift_ms +
                                 self.window_ms) / 1000.0
         self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
@@ -57,22 +55,31 @@ class ChunkBuffer(object):
         self.shift_bytes = int(self.shift_sec * self.sample_rate *
                                self.sample_width)
+        self.remained_audio = b''
+        # abs timestamp from `start` or latest `reset`
+        self.timestamp = 0.0
+
+    def reset(self):
+        """
+        reset buffer state.
+        """
+        self.timestamp = 0.0
+        self.remained_audio = b''
 
     def frame_generator(self, audio):
         """Generates audio frames from PCM audio data.
         Takes the desired frame duration in milliseconds, the PCM data, and
         the sample rate.
         Yields Frames of the requested duration.
         """
         audio = self.remained_audio + audio
         self.remained_audio = b''
         offset = 0
-        timestamp = 0.0
         while offset + self.window_bytes <= len(audio):
-            yield Frame(audio[offset:offset + self.window_bytes], timestamp,
+            yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
                         self.window_sec)
-            timestamp += self.shift_sec
+            self.timestamp += self.shift_sec
             offset += self.shift_bytes
         self.remained_audio += audio[offset:]
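
The point of the change above is that the frame timestamp is now an attribute of the buffer, so it keeps counting across successive `frame_generator()` calls and only restarts on `reset()`. A self-contained sketch of that behaviour (simplified, not the repo class):

```python
class TinyChunker:
    """Simplified stand-in for ChunkBuffer to show the absolute timestamp."""

    def __init__(self, window_bytes, shift_bytes, shift_sec):
        self.window_bytes = window_bytes
        self.shift_bytes = shift_bytes
        self.shift_sec = shift_sec
        self.remained_audio = b''
        self.timestamp = 0.0  # absolute time since start or the latest reset()

    def reset(self):
        self.timestamp = 0.0
        self.remained_audio = b''

    def frame_generator(self, audio):
        audio = self.remained_audio + audio
        self.remained_audio = b''
        offset = 0
        while offset + self.window_bytes <= len(audio):
            yield audio[offset:offset + self.window_bytes], self.timestamp
            self.timestamp += self.shift_sec
            offset += self.shift_bytes
        self.remained_audio = audio[offset:]


chunker = TinyChunker(window_bytes=4, shift_bytes=2, shift_sec=0.01)
for packet in (b'\x00' * 6, b'\x00' * 6):
    for _frame, ts in chunker.frame_generator(packet):
        print(ts)  # keeps increasing across packets instead of restarting at 0.0
```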

@@ -15,8 +15,8 @@ from typing import List
 from fastapi import APIRouter
 
-from paddlespeech.server.ws.asr_socket import router as asr_router
-from paddlespeech.server.ws.tts_socket import router as tts_router
+from paddlespeech.server.ws.asr_api import router as asr_router
+from paddlespeech.server.ws.tts_api import router as tts_router
 
 _router = APIRouter()

@@ -1,110 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter()
@router.websocket('/paddlespeech/asr/streaming')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api
Args:
websocket (WebSocket): the websocket instance
"""
#1. the interface wait to accept the websocket protocal header
# and only we receive the header, it establish the connection with specific thread
await websocket.accept()
#2. if we accept the websocket headers, we will get the online asr engine instance
engine_pool = get_engine_pool()
asr_engine = engine_pool['asr']
#3. each websocket connection, we will create an PaddleASRConnectionHanddler to process such audio
# and each connection has its own connection instance to process the request
# and only if client send the start signal, we create the PaddleASRConnectionHanddler instance
connection_handler = None
try:
#4. we do a loop to process the audio package by package according the protocal
# and only if the client send finished signal, we will break the loop
while True:
# careful here, changed the source code from starlette.websockets
# 4.1 we wait for the client signal for the specific action
assert websocket.application_state == WebSocketState.CONNECTED
message = await websocket.receive()
websocket._raise_on_disconnect(message)
#4.2 text for the action command and bytes for pcm data
if "text" in message:
# we first parse the specific command
message = json.loads(message["text"])
if 'signal' not in message:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
# start command, we create the PaddleASRConnectionHanddler instance to process the audio data
# end command, we process the all the last audio pcm and return the final result
# and we break the loop
if message['signal'] == 'start':
resp = {"status": "ok", "signal": "server_ready"}
# do something at begining here
# create the instance to process the audio
connection_handler = PaddleASRConnectionHanddler(asr_engine)
await websocket.send_json(resp)
elif message['signal'] == 'end':
# reset single engine for an new connection
# and we will destroy the connection
connection_handler.decode(is_finished=True)
connection_handler.rescoring()
asr_results = connection_handler.get_result()
word_time_stamp = connection_handler.get_word_time_stamp()
connection_handler.reset()
resp = {
"status": "ok",
"signal": "finished",
'result': asr_results,
'times': word_time_stamp
}
await websocket.send_json(resp)
break
else:
resp = {"status": "ok", "message": "no valid json data"}
await websocket.send_json(resp)
elif "bytes" in message:
# bytes for the pcm data
message = message["bytes"]
# we extract the remained audio pcm
# and decode for the result in this package data
connection_handler.extract_feat(message)
connection_handler.decode(is_finished=False)
asr_results = connection_handler.get_result()
# return the current period result
# if the engine create the vad instance, this connection will have many period results
resp = {'result': asr_results}
await websocket.send_json(resp)
except WebSocketDisconnect:
pass

@@ -1,61 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from fastapi import APIRouter
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.cli.log import logger
from paddlespeech.server.engine.engine_pool import get_engine_pool
router = APIRouter()
@router.websocket('/paddlespeech/tts/streaming')
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        # careful here, changed the source code from starlette.websockets
        assert websocket.application_state == WebSocketState.CONNECTED
        message = await websocket.receive()
        websocket._raise_on_disconnect(message)

        # get engine
        engine_pool = get_engine_pool()
        tts_engine = engine_pool['tts']

        # get the message and convert it to text
        message = json.loads(message["text"])
        text_bese64 = message["text"]
        sentence = tts_engine.preprocess(text_bese64=text_bese64)

        # run
        wav_generator = tts_engine.run(sentence)

        while True:
            try:
                tts_results = next(wav_generator)
                resp = {"status": 1, "audio": tts_results}
                await websocket.send_json(resp)
            except StopIteration as e:
                resp = {"status": 2, "audio": ''}
                await websocket.send_json(resp)
                logger.info("Complete the transmission of audio streams")
                break
    except WebSocketDisconnect:
        pass
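
Similarly, a hedged sketch of a client for this TTS streaming endpoint: it sends the text (the handler expects it base64-encoded, per `text_bese64`), then reads audio messages until `status` 2 arrives. The URL, port, and encoding details are assumptions for illustration.

```python
import asyncio
import base64
import json

import websockets  # third-party client library, assumed installed


async def synthesize(text, url="ws://127.0.0.1:8092/paddlespeech/tts/streaming"):
    async with websockets.connect(url) as ws:
        payload = {"text": base64.b64encode(text.encode("utf8")).decode("utf8")}
        await ws.send(json.dumps(payload))

        audio_chunks = []
        while True:
            resp = json.loads(await ws.recv())
            if resp["status"] == 2:       # server signals the end of the stream
                break
            audio_chunks.append(resp["audio"])
        return audio_chunks


# asyncio.run(synthesize("hello world"))
```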