code format, test=doc

pull/1652/head
lym0302 3 years ago
parent 759a9e61e4
commit 1a3c811f04

@ -27,6 +27,7 @@ from paddlespeech.s2t.frontend.speech import SpeechSegment
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import pcm2float
from paddlespeech.server.utils.paddle_predictor import init_predictor
__all__ = ['ASREngine']
@ -222,21 +223,6 @@ class ASRServerExecutor(ASRExecutor):
else:
raise Exception("invalid model name")
def _pcm16to32(self, audio):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if audio.dtype == np.int16:
audio = audio.astype("float32")
bits = np.iinfo(np.int16).bits
audio = audio / (2**(bits - 1))
return audio
def extract_feat(self, samples, sample_rate):
"""extract feat
@ -249,7 +235,7 @@ class ASRServerExecutor(ASRExecutor):
x_chunk_lens (numpy.array): shape[B]
"""
# pcm16 -> pcm 32
samples = self._pcm16to32(samples)
samples = pcm2float(samples)
# read audio
speech_segment = SpeechSegment.from_pcm(

@ -12,29 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import time
import librosa
import numpy as np
import paddle
import soundfile as sf
from scipy.io import wavfile
from paddlespeech.cli.log import logger
from paddlespeech.cli.tts.infer import TTSExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils.audio_process import change_speed
from paddlespeech.server.utils.errors import ErrorCode
from paddlespeech.server.utils.exception import ServerBaseException
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.config import get_config
from paddlespeech.server.utils.util import denorm
from paddlespeech.server.utils.util import get_chunks
import math
__all__ = ['TTSEngine']
@ -44,7 +32,8 @@ class TTSServerExecutor(TTSExecutor):
pass
@paddle.no_grad()
def infer(self,
def infer(
self,
text: str,
lang: str='zh',
am: str='fastspeech2_csmsc',
@ -61,8 +50,6 @@ class TTSServerExecutor(TTSExecutor):
get_tone_ids = False
merge_sentences = False
frontend_st = time.time()
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
text,
@ -110,10 +97,12 @@ class TTSServerExecutor(TTSExecutor):
elif i == chunk_num - 1:
sub_wav = sub_wav[front_pad * voc_upsample:]
else:
sub_wav = sub_wav[front_pad * voc_upsample: (front_pad + voc_block) * voc_upsample]
sub_wav = sub_wav[front_pad * voc_upsample:(
front_pad + voc_block) * voc_upsample]
yield sub_wav
class TTSEngine(BaseEngine):
"""TTS server engine
@ -128,9 +117,11 @@ class TTSEngine(BaseEngine):
def init(self, config: dict) -> bool:
self.executor = TTSServerExecutor()
try:
self.config = config
assert "fastspeech2_csmsc" in config.am and (
config.voc == "hifigan_csmsc-zh" or config.voc == "mb_melgan_csmsc"
), 'Please check config, am support: fastspeech2, voc support: hifigan_csmsc-zh or mb_melgan_csmsc.'
try:
if self.config.device:
self.device = self.config.device
else:
@ -181,81 +172,6 @@ class TTSEngine(BaseEngine):
return text
def postprocess(self,
wav,
original_fs: int,
target_fs: int=0,
volume: float=1.0,
speed: float=1.0,
audio_path: str=None):
"""Post-processing operations, including speech, volume, sample rate, save audio file
Args:
wav (numpy(float)): Synthesized audio sample points
original_fs (int): original audio sample rate
target_fs (int): target audio sample rate
volume (float): target volume
speed (float): target speed
Raises:
ServerBaseException: Throws an exception if the change speed unsuccessfully.
Returns:
target_fs: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
# transform sample_rate
if target_fs == 0 or target_fs > original_fs:
target_fs = original_fs
wav_tar_fs = wav
logger.info(
"The sample rate of synthesized audio is the same as model, which is {}Hz".
format(original_fs))
else:
wav_tar_fs = librosa.resample(
np.squeeze(wav), original_fs, target_fs)
logger.info(
"The sample rate of model is {}Hz and the target sample rate is {}Hz. Converting the sample rate of the synthesized audio successfully.".
format(original_fs, target_fs))
# transform volume
wav_vol = wav_tar_fs * volume
logger.info("Transform the volume of the audio successfully.")
# transform speed
try: # windows not support soxbindings
wav_speed = change_speed(wav_vol, speed, target_fs)
logger.info("Transform the speed of the audio successfully.")
except ServerBaseException:
raise ServerBaseException(
ErrorCode.SERVER_INTERNAL_ERR,
"Failed to transform speed. Can not install soxbindings on your system. \
You need to set speed value 1.0.")
except BaseException:
logger.error("Failed to transform speed.")
# wav to base64
buf = io.BytesIO()
wavfile.write(buf, target_fs, wav_speed)
base64_bytes = base64.b64encode(buf.read())
wav_base64 = base64_bytes.decode('utf-8')
logger.info("Audio to string successfully.")
# save audio
if audio_path is not None:
if audio_path.endswith(".wav"):
sf.write(audio_path, wav_speed, target_fs)
elif audio_path.endswith(".pcm"):
wav_norm = wav_speed * (32767 / max(0.001,
np.max(np.abs(wav_speed))))
with open(audio_path, "wb") as f:
f.write(wav_norm.astype(np.int16))
logger.info("Save audio to {} successfully.".format(audio_path))
else:
logger.info("There is no need to save audio.")
return target_fs, wav_base64
def run(self,
sentence: str,
spk_id: int=0,
@ -275,20 +191,22 @@ class TTSEngine(BaseEngine):
save_path (str, optional): The save path of the synthesized audio.
None means do not save audio. Defaults to None.
Raises:
ServerBaseException: Throws an exception if tts inference unsuccessfully.
ServerBaseException: Throws an exception if postprocess unsuccessfully.
Returns:
lang: model language
target_sample_rate: target sample rate for synthesized audio.
wav_base64: The base64 format of the synthesized audio.
"""
lang = self.config.lang
wav_list = []
for wav in self.executor.infer(text=sentence, lang=lang, am=self.config.am, spk_id=spk_id, am_block=self.am_block, am_pad=self.am_pad, voc_block=self.voc_block, voc_pad=self.voc_pad):
for wav in self.executor.infer(
text=sentence,
lang=lang,
am=self.config.am,
spk_id=spk_id,
am_block=self.am_block,
am_pad=self.am_pad,
voc_block=self.voc_block,
voc_pad=self.voc_pad):
# wav type: <class 'numpy.ndarray'> float32, convert to pcm (base64)
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
@ -298,8 +216,5 @@ class TTSEngine(BaseEngine):
yield wav_base64
wav_all = np.concatenate(wav_list, axis=0)
logger.info("The durations of audio is: {} s".format(len(wav_all)/self.executor.am_config.fs))
logger.info("The durations of audio is: {} s".format(
len(wav_all) / self.executor.am_config.fs))

@ -25,7 +25,7 @@ st = 0.0
all_bytes = b''
class Ws_Param(object):
class WsParam(object):
# 初始化
def __init__(self, text, server="127.0.0.1", port=8090):
self.server = server
@ -116,7 +116,7 @@ if __name__ == "__main__":
print("Sentence to be synthesized: ", args.text)
print("***************************************")
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
wsParam = WsParam(text=args.text, server=args.server, port=args.port)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()

@ -32,7 +32,7 @@ st = 0.0
all_bytes = 0.0
class Ws_Param(object):
class WsParam(object):
# 初始化
def __init__(self, text, server="127.0.0.1", port=8090):
self.server = server
@ -144,7 +144,7 @@ if __name__ == "__main__":
print("Sentence to be synthesized: ", args.text)
print("***************************************")
wsParam = Ws_Param(text=args.text, server=args.server, port=args.port)
wsParam = WsParam(text=args.text, server=args.server, port=args.port)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()

@ -126,3 +126,17 @@ def float2pcm(sig, dtype='int16'):
abs_max = 2**(i.bits - 1)
offset = i.min + abs_max
return (sig * abs_max + offset).clip(i.min, i.max).astype(dtype)
def pcm2float(data):
"""pcm int16 to float32
Args:
audio(numpy.array): numpy.int16
Returns:
audio(numpy.array): numpy.float32
"""
if data.dtype == np.int16:
data = data.astype("float32")
bits = np.iinfo(np.int16).bits
data = data / (2**(bits - 1))
return data

@ -35,10 +35,23 @@ def self_check():
def denorm(data, mean, std):
"""stream am model need to denorm
"""
return data * std + mean
def get_chunks(data, block_size, pad_size, step):
"""Divide data into multiple chunks
Args:
data (tensor): data
block_size (int): [description]
pad_size (int): [description]
step (str): set "am" or "voc", generate chunk for step am or vocoder(voc)
Returns:
list: chunks list
"""
if step == "am":
data_len = data.shape[1]
elif step == "voc":

@ -44,11 +44,11 @@ async def websocket_endpoint(websocket: WebSocket):
sentence = tts_engine.preprocess(text_bese64=text_bese64)
# run
wav = tts_engine.run(sentence)
wav_generator = tts_engine.run(sentence)
while True:
try:
tts_results = next(wav)
tts_results = next(wav_generator)
resp = {"status": 1, "audio": tts_results}
await websocket.send_json(resp)
logger.info("streaming audio...")

Loading…
Cancel
Save