From deab2c75ba6ff6583ef71019cc5ed485f091f421 Mon Sep 17 00:00:00 2001
From: 吕志轩 <8252801+lv-zhixuan@user.noreply.gitee.com>
Date: Sat, 24 Sep 2022 11:18:02 +0800
Subject: [PATCH] Trained a 24k transformer vocoder, and revised the
 preprocess and synthesize code to follow the standardized format
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 demos/audio_searching/src/operations/load.py | 5 +-
 demos/speech_web/API.md | 2 +-
 demos/speech_web/speech_server/main.py | 163 +++++-----
 .../speech_web/speech_server/requirements.txt | 11 +-
 .../speech_server/src/AudioManeger.py | 88 +++---
 .../speech_server/src/SpeechBase/asr.py | 17 +-
 .../speech_server/src/SpeechBase/nlp.py | 18 +-
 .../src/SpeechBase/sql_helper.py | 49 +--
 .../speech_server/src/SpeechBase/tts.py | 93 +++---
 .../speech_server/src/SpeechBase/vpr.py | 55 ++--
 .../src/SpeechBase/vpr_encode.py | 9 +-
 .../speech_server/src/WebsocketManeger.py | 3 +-
 demos/speech_web/speech_server/src/robot.py | 45 +--
 demos/speech_web/speech_server/src/util.py | 17 +-
 .../local/rtf_from_log.py | 2 +-
 docs/requirements.txt | 35 ++-
 docs/source/conf.py | 5 +-
 examples/csmsc/tts1/README.md | 16 +-
 examples/csmsc/tts1/conf/default.yaml | 19 +-
 examples/csmsc/tts1/local/preprocess.sh | 20 +-
 examples/csmsc/tts1/local/synthesize.sh | 100 +++++-
 examples/csmsc/tts1/local/synthesize_e2e.sh | 114 ++++++-
 examples/ernie_sat/local/inference.py | 6 +-
 examples/ernie_sat/local/inference_new.py | 8 +-
 paddlespeech/__init__.py | 2 -
 paddlespeech/audio/__init__.py | 6 +-
 paddlespeech/audio/streamdata/__init__.py | 125 ++++----
 paddlespeech/audio/streamdata/autodecode.py | 19 +-
 paddlespeech/audio/streamdata/cache.py | 64 ++--
 paddlespeech/audio/streamdata/compat.py | 65 ++--
 .../audio/streamdata/extradatasets.py | 6 +-
 paddlespeech/audio/streamdata/filters.py | 247 +++++++++------
 paddlespeech/audio/streamdata/gopen.py | 62 ++--
 paddlespeech/audio/streamdata/handlers.py | 5 +-
 paddlespeech/audio/streamdata/mix.py | 17 +-
 paddlespeech/audio/streamdata/paddle_utils.py | 3 +-
 paddlespeech/audio/streamdata/pipeline.py | 15 +-
 paddlespeech/audio/streamdata/shardlists.py | 75 +++--
 paddlespeech/audio/streamdata/tariterators.py | 81 ++---
 paddlespeech/audio/streamdata/utils.py | 31 +-
 paddlespeech/audio/streamdata/writer.py | 77 ++---
 paddlespeech/audio/text/text_featurizer.py | 2 +-
 paddlespeech/audio/transform/perturb.py | 11 +-
 paddlespeech/audio/transform/spec_augment.py | 1 +
 paddlespeech/cli/executor.py | 2 +-
 paddlespeech/s2t/__init__.py | 1 +
 paddlespeech/s2t/exps/u2/model.py | 23 +-
 paddlespeech/s2t/exps/u2_kaldi/model.py | 26 +-
 paddlespeech/s2t/exps/u2_st/model.py | 19 +-
 paddlespeech/s2t/io/dataloader.py | 145 +++++----
 paddlespeech/s2t/io/sampler.py | 2 +-
 paddlespeech/s2t/models/u2_st/u2_st.py | 12 +-
 paddlespeech/s2t/modules/align.py | 39 ++-
 paddlespeech/s2t/modules/attention.py | 23 +-
 .../s2t/modules/conformer_convolution.py | 17 +-
 paddlespeech/s2t/modules/encoder.py | 29 +-
 paddlespeech/s2t/modules/initializer.py | 1 +
 .../server/bin/paddlespeech_server.py | 2 +-
 .../server/engine/asr/online/ctc_endpoint.py | 6 +-
 .../engine/asr/online/onnx/asr_engine.py | 2 +-
 .../asr/online/paddleinference/asr_engine.py | 2 +-
 .../engine/asr/online/python/asr_engine.py | 13 +-
 paddlespeech/t2s/exps/ernie_sat/align.py | 9 +-
 .../t2s/exps/ernie_sat/synthesize_e2e.py | 68 ++---
 paddlespeech/t2s/exps/ernie_sat/utils.py | 11 +-
 paddlespeech/t2s/exps/tacotron2/normalize.py | 2 +-
 .../exps/transformer_tts/preprocess_new.py | 284 ++++++------------
 .../t2s/exps/transformer_tts/synthesize.py | 52 ++--
 .../exps/transformer_tts/synthesize_e2e.py | 76 +++--
 .../t2s/modules/transformer/repeat.py | 2 +-
 setup.py | 7 +-
 .../ds2_ol/onnx/local/onnx_infer_shape.py | 31 +-
 72 files changed, 1512 insertions(+), 1208 deletions(-)

diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py
index 0d9edb784..d1ea00576 100644
--- a/demos/audio_searching/src/operations/load.py
+++ b/demos/audio_searching/src/operations/load.py
@@ -26,8 +26,9 @@ def get_audios(path):
     """
     supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
     return [
-        item for sublist in [[os.path.join(dir, file) for file in files]
-                             for dir, _, files in list(os.walk(path))]
+        item
+        for sublist in [[os.path.join(dir, file) for file in files]
+                        for dir, _, files in list(os.walk(path))]
         for item in sublist if os.path.splitext(item)[1] in supported_formats
     ]
diff --git a/demos/speech_web/API.md b/demos/speech_web/API.md
index c51446749..f66ec138e 100644
--- a/demos/speech_web/API.md
+++ b/demos/speech_web/API.md
@@ -401,4 +401,4 @@ curl -X 'GET' \
   "code": 0,
   "result":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
   "message": "ok"
-```
\ No newline at end of file
+```
diff --git a/demos/speech_web/speech_server/main.py b/demos/speech_web/speech_server/main.py
index b10176670..8d1de2abe 100644
--- a/demos/speech_web/speech_server/main.py
+++ b/demos/speech_web/speech_server/main.py
@@ -3,48 +3,53 @@
 # 2. 接收录音音频,返回识别结果
 # 3. 接收ASR识别结果,返回NLP对话结果
 # 4. 
接收NLP对话结果,返回TTS音频 - +import argparse import base64 -import yaml -import os -import json import datetime +import json +import os +from typing import List +from typing import Optional + +import aiofiles import librosa -import soundfile as sf import numpy as np -import argparse +import soundfile as sf import uvicorn -import aiofiles -from typing import Optional, List -from pydantic import BaseModel -from fastapi import FastAPI, Header, File, UploadFile, Form, Cookie, WebSocket, WebSocketDisconnect +import yaml +from fastapi import Cookie +from fastapi import FastAPI +from fastapi import File +from fastapi import Form +from fastapi import Header +from fastapi import UploadFile +from fastapi import WebSocket +from fastapi import WebSocketDisconnect from fastapi.responses import StreamingResponse -from starlette.responses import FileResponse -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.websockets import WebSocketState as WebSocketState - +from pydantic import BaseModel from src.AudioManeger import AudioMannger -from src.util import * from src.robot import Robot -from src.WebsocketManeger import ConnectionManager from src.SpeechBase.vpr import VPR +from src.util import * +from src.WebsocketManeger import ConnectionManager +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse +from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.utils.audio_process import float2pcm - # 解析配置 -parser = argparse.ArgumentParser( - prog='PaddleSpeechDemo', add_help=True) +parser = argparse.ArgumentParser(prog='PaddleSpeechDemo', add_help=True) parser.add_argument( - "--port", - action="store", - type=int, - help="port of the app", - default=8010, - required=False) + "--port", + action="store", + type=int, + help="port of the app", + default=8010, + required=False) args = parser.parse_args() port = args.port @@ -60,39 +65,41 @@ ie_model_path = "source/model" UPLOAD_PATH = "source/vpr" WAV_PATH = "source/wav" - -base_sources = [ - UPLOAD_PATH, WAV_PATH -] +base_sources = [UPLOAD_PATH, WAV_PATH] for path in base_sources: os.makedirs(path, exist_ok=True) - # 初始化 app = FastAPI() -chatbot = Robot(asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path) +chatbot = Robot( + asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path) manager = ConnectionManager() aumanager = AudioMannger(chatbot) aumanager.init() -vpr = VPR(db_path, dim = 192, top_k = 5) +vpr = VPR(db_path, dim=192, top_k=5) + # 服务配置 class NlpBase(BaseModel): chat: str + class TtsBase(BaseModel): - text: str + text: str + class Audios: def __init__(self) -> None: self.audios = b"" + audios = Audios() ###################################################################### ########################### ASR 服务 ################################# ##################################################################### + # 接收文件,返回ASR结果 # 上传文件 @app.post("/asr/offline") @@ -101,7 +108,8 @@ async def speech2textOffline(files: List[UploadFile]): asr_res = "" for file in files[:1]: # 生成时间戳 - now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "asr_offline_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) 
async with aiofiles.open(out_file_path, 'wb') as out_file: content = await file.read() # async read @@ -111,9 +119,10 @@ async def speech2textOffline(files: List[UploadFile]): asr_res = chatbot.speech2text(out_file_path) return SuccessRequest(result=asr_res) # else: - # return ErrorRequest(message="文件不是.wav格式") + # return ErrorRequest(message="文件不是.wav格式") return ErrorRequest(message="上传文件为空") + # 接收文件,同时将wav强制转成16k, int16类型 @app.post("/asr/offlinefile") async def speech2textOfflineFile(files: List[UploadFile]): @@ -121,7 +130,8 @@ async def speech2textOfflineFile(files: List[UploadFile]): asr_res = "" for file in files[:1]: # 生成时间戳 - now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "asr_offline_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) async with aiofiles.open(out_file_path, 'wb') as out_file: content = await file.read() # async read @@ -132,22 +142,18 @@ async def speech2textOfflineFile(files: List[UploadFile]): wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') - + # 将文件重新写入 now_name = now_name[:-4] + "_16k" + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) - sf.write(out_file_path,wav,16000) + sf.write(out_file_path, wav, 16000) # 返回ASR识别结果 asr_res = chatbot.speech2text(out_file_path) - response_res = { - "asr_result": asr_res, - "wav_base64": wav_base64 - } + response_res = {"asr_result": asr_res, "wav_base64": wav_base64} return SuccessRequest(result=response_res) - - return ErrorRequest(message="上传文件为空") + return ErrorRequest(message="上传文件为空") # 流式接收测试 @@ -161,15 +167,17 @@ async def speech2textOnlineRecive(files: List[UploadFile]): print(f"audios长度变化: {len(audios.audios)}") return SuccessRequest(message="接收成功") + # 采集环境噪音大小 @app.post("/asr/collectEnv") async def collectEnv(files: List[UploadFile]): - for file in files[:1]: + for file in files[:1]: content = await file.read() # async read # 初始化, wav 前44字节是头部信息 aumanager.compute_env_volume(content[44:]) vad_ = aumanager.vad_threshold - return SuccessRequest(result=vad_,message="采集环境噪音成功") + return SuccessRequest(result=vad_, message="采集环境噪音成功") + # 停止录音 @app.get("/asr/stopRecord") @@ -179,6 +187,7 @@ async def stopRecord(): print("Online录音暂停") return SuccessRequest(message="停止成功") + # 恢复录音 @app.get("/asr/resumeRecord") async def resumeRecord(): @@ -210,7 +219,7 @@ async def websocket_endpoint(websocket: WebSocket): # print(f"用户-{user}-离开") -# Online识别的ASR + # Online识别的ASR @app.websocket('/ws/asr/onlineStream') async def websocket_endpoint(websocket: WebSocket): """PaddleSpeech Online ASR Server api @@ -298,12 +307,14 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass + ###################################################################### ########################### NLP 服务 ################################# ##################################################################### + @app.post("/nlp/chat") -async def chatOffline(nlp_base:NlpBase): +async def chatOffline(nlp_base: NlpBase): chat = nlp_base.chat if not chat: return ErrorRequest(message="传入文本为空") @@ -311,8 +322,9 @@ async def chatOffline(nlp_base:NlpBase): res = chatbot.chat(chat) return SuccessRequest(result=res) + @app.post("/nlp/ie") -async def ieOffline(nlp_base:NlpBase): +async def ieOffline(nlp_base: NlpBase): nlp_text = nlp_base.chat if not nlp_text: return 
ErrorRequest(message="传入文本为空") @@ -320,17 +332,20 @@ async def ieOffline(nlp_base:NlpBase): res = chatbot.ie(nlp_text) return SuccessRequest(result=res) + ###################################################################### ########################### TTS 服务 ################################# ##################################################################### + @app.post("/tts/offline") -async def text2speechOffline(tts_base:TtsBase): +async def text2speechOffline(tts_base: TtsBase): text = tts_base.text if not text: return ErrorRequest(message="文本为空") else: - now_name = "tts_"+ datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "tts_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) # 保存为文件,再转成base64传输 chatbot.text2speech(text, outpath=out_file_path) @@ -339,12 +354,14 @@ async def text2speechOffline(tts_base:TtsBase): base_str = base64.b64encode(data_bin) return SuccessRequest(result=base_str) + # http流式TTS @app.post("/tts/online") async def stream_tts(request_body: TtsBase): text = request_body.text return StreamingResponse(chatbot.text2speechStreamBytes(text=text)) + # ws流式TTS @app.websocket("/ws/tts/online") async def stream_ttsWS(websocket: WebSocket): @@ -356,17 +373,11 @@ async def stream_ttsWS(websocket: WebSocket): if text: for sub_wav in chatbot.text2speechStream(text=text): # print("发送sub wav: ", len(sub_wav)) - res = { - "wav": sub_wav, - "done": False - } + res = {"wav": sub_wav, "done": False} await websocket.send_json(res) - + # 输送结束 - res = { - "wav": sub_wav, - "done": True - } + res = {"wav": sub_wav, "done": True} await websocket.send_json(res) # manager.disconnect(websocket) @@ -396,8 +407,9 @@ async def vpr_enroll(table_name: str=None, return {'status': False, 'msg': "spk_id can not be None"} # Save the upload data to server. content = await audio.read() - now_name = "vpr_enroll_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" - audio_path = os.path.join(UPLOAD_PATH, now_name) + now_name = "vpr_enroll_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + audio_path = os.path.join(UPLOAD_PATH, now_name) with open(audio_path, "wb+") as f: f.write(content) @@ -413,12 +425,13 @@ async def vpr_recog(request: Request, audio: UploadFile=File(...)): # Voice print recognition online # try: - # Save the upload data to server. + # Save the upload data to server. 
content = await audio.read() - now_name = "vpr_query_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" - query_audio_path = os.path.join(UPLOAD_PATH, now_name) + now_name = "vpr_query_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + query_audio_path = os.path.join(UPLOAD_PATH, now_name) with open(query_audio_path, "wb+") as f: - f.write(content) + f.write(content) spk_ids, paths, scores = vpr.do_search_vpr(query_audio_path) res = dict(zip(spk_ids, zip(paths, scores))) @@ -426,7 +439,9 @@ async def vpr_recog(request: Request, res = sorted(res.items(), key=lambda item: item[1][1], reverse=True) return res # except Exception as e: - # return {'status': False, 'msg': e}, 400 + + +# return {'status': False, 'msg': e}, 400 @app.post('/vpr/del') @@ -460,17 +475,18 @@ async def vpr_database64(vprId: int): return {'status': False, 'msg': "vpr_id can not be None"} audio_path = vpr.do_get_wav(vprId) # 返回base64 - + # 将文件转成16k, 16bit类型的wav文件 wav, sr = librosa.load(audio_path, sr=16000) wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') - + return SuccessRequest(result=wav_base64) except Exception as e: return {'status': False, 'msg': e}, 400 + @app.get('/vpr/data') async def vpr_data(vprId: int): # Get the audio file from path by spk_id in MySQL @@ -482,11 +498,6 @@ async def vpr_data(vprId: int): except Exception as e: return {'status': False, 'msg': e}, 400 + if __name__ == '__main__': uvicorn.run(app=app, host='0.0.0.0', port=port) - - - - - - diff --git a/demos/speech_web/speech_server/requirements.txt b/demos/speech_web/speech_server/requirements.txt index 7e7bd1680..607f0d4d0 100644 --- a/demos/speech_web/speech_server/requirements.txt +++ b/demos/speech_web/speech_server/requirements.txt @@ -1,14 +1,13 @@ aiofiles +faiss-cpu fastapi librosa numpy +paddlenlp +paddlepaddle +paddlespeech pydantic -scikit_learn +python-multipartscikit_learn SoundFile starlette uvicorn -paddlepaddle -paddlespeech -paddlenlp -faiss-cpu -python-multipart \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/AudioManeger.py b/demos/speech_web/speech_server/src/AudioManeger.py index 0deb03699..8daa07a32 100644 --- a/demos/speech_web/speech_server/src/AudioManeger.py +++ b/demos/speech_web/speech_server/src/AudioManeger.py @@ -1,15 +1,22 @@ +import datetime import imp -from queue import Queue -import numpy as np import os -import wave import random -import datetime +import wave +from queue import Queue + +import numpy as np + from .util import randName class AudioMannger: - def __init__(self, robot, frame_length=160, frame=10, data_width=2, vad_default = 300): + def __init__(self, + robot, + frame_length=160, + frame=10, + data_width=2, + vad_default=300): # 二进制 pcm 流 self.audios = b'' self.asr_result = "" @@ -20,8 +27,9 @@ class AudioMannger: os.makedirs(self.file_dir, exist_ok=True) self.vad_deafult = vad_default self.vad_threshold = vad_default - self.vad_threshold_path = os.path.join(self.file_dir, "vad_threshold.npy") - + self.vad_threshold_path = os.path.join(self.file_dir, + "vad_threshold.npy") + # 10ms 一帧 self.frame_length = frame_length # 10帧,检测一次 vad @@ -30,67 +38,64 @@ class AudioMannger: self.data_width = data_width # window self.window_length = frame_length * frame * data_width - + # 是否开始录音 self.on_asr = False - self.silence_cnt = 0 + self.silence_cnt = 0 self.max_silence_cnt = 4 self.is_pause = False # 录音暂停与恢复 - 
- - + def init(self): if os.path.exists(self.vad_threshold_path): # 平均响度文件存在 self.vad_threshold = np.load(self.vad_threshold_path) - - + def clear_audio(self): # 清空 pcm 累积片段与 asr 识别结果 self.audios = b'' - + def clear_asr(self): self.asr_result = "" - - + def compute_chunk_volume(self, start_index, pcm_bins): # 根据帧长计算能量平均值 - pcm_bin = pcm_bins[start_index: start_index + self.window_length] + pcm_bin = pcm_bins[start_index:start_index + self.window_length] # 转成 numpy pcm_np = np.frombuffer(pcm_bin, np.int16) # 归一化 + 计算响度 x = pcm_np.astype(np.float32) x = np.abs(x) - return np.mean(x) - - + return np.mean(x) + def is_speech(self, start_index, pcm_bins): # 检查是否没 if start_index > len(pcm_bins): return False # 检查从这个 start 开始是否为静音帧 - energy = self.compute_chunk_volume(start_index=start_index, pcm_bins=pcm_bins) + energy = self.compute_chunk_volume( + start_index=start_index, pcm_bins=pcm_bins) # print(energy) if energy > self.vad_threshold: return True else: return False - + def compute_env_volume(self, pcm_bins): max_energy = 0 start = 0 while start < len(pcm_bins): - energy = self.compute_chunk_volume(start_index=start, pcm_bins=pcm_bins) + energy = self.compute_chunk_volume( + start_index=start, pcm_bins=pcm_bins) if energy > max_energy: max_energy = energy start += self.window_length self.vad_threshold = max_energy + 100 if max_energy > self.vad_deafult else self.vad_deafult - + # 保存成文件 np.save(self.vad_threshold_path, self.vad_threshold) print(f"vad 阈值大小: {self.vad_threshold}") print(f"环境采样保存: {os.path.realpath(self.vad_threshold_path)}") - + def stream_asr(self, pcm_bin): # 先把 pcm_bin 送进去做端点检测 start = 0 @@ -99,7 +104,7 @@ class AudioMannger: self.on_asr = True self.silence_cnt = 0 print("录音中") - self.audios += pcm_bin[ start : start + self.window_length] + self.audios += pcm_bin[start:start + self.window_length] else: if self.on_asr: self.silence_cnt += 1 @@ -110,41 +115,42 @@ class AudioMannger: print("录音停止") # audios 保存为 wav, 送入 ASR if len(self.audios) > 2 * 16000: - file_path = os.path.join(self.file_dir, "asr_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav") + file_path = os.path.join( + self.file_dir, + "asr_" + datetime.datetime.strftime( + datetime.datetime.now(), + '%Y%m%d%H%M%S') + randName() + ".wav") self.save_audio(file_path=file_path) self.asr_result = self.robot.speech2text(file_path) self.clear_audio() - return self.asr_result + return self.asr_result else: # 正常接收 print("录音中 静音") - self.audios += pcm_bin[ start : start + self.window_length] + self.audios += pcm_bin[start:start + self.window_length] start += self.window_length return "" - + def save_audio(self, file_path): print("保存音频") - wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav" - wf.setnchannels(1) # 设置声道数为2 - wf.setsampwidth(2) # 设置采样深度为 - wf.setframerate(16000) # 设置采样率为16000 + wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav" + wf.setnchannels(1) # 设置声道数为2 + wf.setsampwidth(2) # 设置采样深度为 + wf.setframerate(16000) # 设置采样率为16000 # 将数据写入创建的音频文件 wf.writeframes(self.audios) # 写完后将文件关闭 wf.close() - + def end(self): # audios 保存为 wav, 送入 ASR file_path = os.path.join(self.file_dir, "asr.wav") self.save_audio(file_path=file_path) return self.robot.speech2text(file_path) - + def stop(self): self.is_pause = True self.audios = b'' - + def resume(self): self.is_pause = False - - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/asr.py b/demos/speech_web/speech_server/src/SpeechBase/asr.py index 8d4c0cffc..3b9e05c8b 100644 --- 
a/demos/speech_web/speech_server/src/SpeechBase/asr.py +++ b/demos/speech_web/speech_server/src/SpeechBase/asr.py @@ -1,13 +1,15 @@ from re import sub + +import librosa import numpy as np import paddle -import librosa import soundfile from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.utils.config import get_config + def readWave(samples): x_len = len(samples) @@ -31,20 +33,23 @@ def readWave(samples): class ASR: - def __init__(self, config_path, ) -> None: + def __init__( + self, + config_path, ) -> None: self.config = get_config(config_path)['asr_online'] self.engine = ASREngine() self.engine.init(self.config) self.connection_handler = PaddleASRConnectionHanddler(self.engine) - + def offlineASR(self, samples, sample_rate=16000): - x_chunk, x_chunk_lens = self.engine.preprocess(samples=samples, sample_rate=sample_rate) + x_chunk, x_chunk_lens = self.engine.preprocess( + samples=samples, sample_rate=sample_rate) self.engine.run(x_chunk, x_chunk_lens) result = self.engine.postprocess() self.engine.reset() return result - def onlineASR(self, samples:bytes=None, is_finished=False): + def onlineASR(self, samples: bytes=None, is_finished=False): if not is_finished: # 流式开始 self.connection_handler.extract_feat(samples) @@ -58,5 +63,3 @@ class ASR: asr_results = self.connection_handler.get_result() self.connection_handler.reset() return asr_results - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/nlp.py b/demos/speech_web/speech_server/src/SpeechBase/nlp.py index 4ece63256..b642a51d6 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/nlp.py +++ b/demos/speech_web/speech_server/src/SpeechBase/nlp.py @@ -1,23 +1,23 @@ from paddlenlp import Taskflow + class NLP: def __init__(self, ie_model_path=None): schema = ["时间", "出发地", "目的地", "费用"] if ie_model_path: - self.ie_model = Taskflow("information_extraction", - schema=schema, task_path=ie_model_path) + self.ie_model = Taskflow( + "information_extraction", + schema=schema, + task_path=ie_model_path) else: - self.ie_model = Taskflow("information_extraction", - schema=schema) - + self.ie_model = Taskflow("information_extraction", schema=schema) + self.dialogue_model = Taskflow("dialogue") - + def chat(self, text): result = self.dialogue_model([text]) return result[0] - + def ie(self, text): result = self.ie_model(text) return result - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py index 6937def58..628dbc4e8 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py +++ b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py @@ -1,18 +1,20 @@ import base64 -import sqlite3 import os +import sqlite3 + import numpy as np from pkg_resources import resource_stream -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d +def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + class DataBase(object): - def __init__(self, db_path:str): + def __init__(self, db_path: str): db_path = os.path.realpath(db_path) if os.path.exists(db_path): @@ -21,12 +23,12 @@ class DataBase(object): db_path_dir = os.path.dirname(db_path) os.makedirs(db_path_dir, exist_ok=True) self.db_path = db_path - + self.conn = 
sqlite3.connect(self.db_path) self.conn.row_factory = dict_factory self.cursor = self.conn.cursor() self.init_database() - + def init_database(self): """ 初始化数据库, 若表不存在则创建 @@ -41,12 +43,12 @@ class DataBase(object): """ self.cursor.execute(sql) self.conn.commit() - + def execute_base(self, sql, data_dict): self.cursor.execute(sql, data_dict) self.conn.commit() - - def insert_one(self, username, vector_base64:str, wav_path): + + def insert_one(self, username, vector_base64: str, wav_path): if not os.path.exists(wav_path): return None, "wav not exists" else: @@ -55,6 +57,7 @@ class DataBase(object): vprtable (username, vector, wavpath) values (?, ?, ?) """ + try: self.cursor.execute(sql, (username, vector_base64, wav_path)) self.conn.commit() @@ -63,25 +66,27 @@ class DataBase(object): except Exception as e: print(e) return None, e - + def select_all(self): sql = """ SELECT * from vprtable """ result = self.cursor.execute(sql).fetchall() return result - + def select_by_id(self, vpr_id): sql = f""" SELECT * from vprtable WHERE `id` = {vpr_id} """ + result = self.cursor.execute(sql).fetchall() return result - + def select_by_username(self, username): sql = f""" SELECT * from vprtable WHERE `username` = '{username}' """ + result = self.cursor.execute(sql).fetchall() return result @@ -89,28 +94,30 @@ class DataBase(object): sql = f""" DELETE from vprtable WHERE `username`='{username}' """ + self.cursor.execute(sql) self.conn.commit() - + def drop_all(self): sql = f""" DELETE from vprtable """ + self.cursor.execute(sql) self.conn.commit() - + def drop_table(self): sql = f""" DROP TABLE vprtable """ + self.cursor.execute(sql) self.conn.commit() - - def encode_vector(self, vector:np.ndarray): + + def encode_vector(self, vector: np.ndarray): return base64.b64encode(vector).decode('utf8') - + def decode_vector(self, vector_base64, dtype=np.float32): b = base64.b64decode(vector_base64) vc = np.frombuffer(b, dtype=dtype) return vc - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/tts.py b/demos/speech_web/speech_server/src/SpeechBase/tts.py index d5ba0c802..030f8eef0 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/tts.py +++ b/demos/speech_web/speech_server/src/SpeechBase/tts.py @@ -5,18 +5,20 @@ # 2. 加载模型 # 3. 端到端推理 # 4. 
流式推理 - import base64 -import math import logging +import math + import numpy as np -from paddlespeech.server.utils.onnx_infer import get_sess -from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.server.utils.util import denorm, get_chunks + +from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.onnx_infer import get_sess +from paddlespeech.server.utils.util import denorm +from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine class TTS: def __init__(self, config_path): @@ -26,12 +28,12 @@ class TTS: self.engine.init(self.config) self.executor = self.engine.executor #self.engine.warm_up() - + # 前端初始化 self.frontend = Frontend( - phone_vocab_path=self.engine.executor.phones_dict, - tone_vocab_path=None) - + phone_vocab_path=self.engine.executor.phones_dict, + tone_vocab_path=None) + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): """ Streaming inference removes the result of pad inference @@ -48,39 +50,37 @@ class TTS: data = data[front_pad * upsample:(front_pad + block) * upsample] return data - + def offlineTTS(self, text): get_tone_ids = False merge_sentences = False - + input_ids = self.frontend.get_input_ids( - text, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) + text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] wav_list = [] for i in range(len(phone_ids)): orig_hs = self.engine.executor.am_encoder_infer_sess.run( - None, input_feed={'text': phone_ids[i].numpy()} - ) + None, input_feed={'text': phone_ids[i].numpy()}) hs = orig_hs[0] am_decoder_output = self.engine.executor.am_decoder_sess.run( - None, input_feed={'xs': hs}) + None, input_feed={'xs': hs}) am_postnet_output = self.engine.executor.am_postnet_sess.run( - None, - input_feed={ - 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) - }) + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) am_output_data = am_decoder_output + np.transpose( am_postnet_output[0], (0, 2, 1)) normalized_mel = am_output_data[0][0] - mel = denorm(normalized_mel, self.engine.executor.am_mu, self.engine.executor.am_std) + mel = denorm(normalized_mel, self.engine.executor.am_mu, + self.engine.executor.am_std) wav = self.engine.executor.voc_sess.run( - output_names=None, input_feed={'logmel': mel})[0] + output_names=None, input_feed={'logmel': mel})[0] wav_list.append(wav) wavs = np.concatenate(wav_list) return wavs - + def streamTTS(self, text): get_tone_ids = False @@ -88,9 +88,7 @@ class TTS: # front input_ids = self.frontend.get_input_ids( - text, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) + text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] for i in range(len(phone_ids)): @@ -105,14 +103,15 @@ class TTS: mel = mel[0] # voc streaming - mel_chunks = get_chunks(mel, self.config.voc_block, self.config.voc_pad, "voc") + mel_chunks = get_chunks(mel, self.config.voc_block, + self.config.voc_pad, "voc") voc_chunk_num = len(mel_chunks) for i, mel_chunk in enumerate(mel_chunks): sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': mel_chunk}) - sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i, - 
self.config.voc_block, self.config.voc_pad, - self.config.voc_upsample) + sub_wav = self.depadding( + sub_wav[0], voc_chunk_num, i, self.config.voc_block, + self.config.voc_pad, self.config.voc_upsample) yield self.after_process(sub_wav) @@ -130,7 +129,8 @@ class TTS: end = min(self.config.voc_block + self.config.voc_pad, mel_len) # streaming am - hss = get_chunks(orig_hs, self.config.am_block, self.config.am_pad, "am") + hss = get_chunks(orig_hs, self.config.am_block, + self.config.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): am_decoder_output = self.executor.am_decoder_sess.run( @@ -147,7 +147,8 @@ class TTS: sub_mel = denorm(normalized_mel, self.executor.am_mu, self.executor.am_std) sub_mel = self.depadding(sub_mel, am_chunk_num, i, - self.config.am_block, self.config.am_pad, 1) + self.config.am_block, + self.config.am_pad, 1) if i == 0: mel_streaming = sub_mel @@ -165,23 +166,22 @@ class TTS: output_names=None, input_feed={'logmel': voc_chunk}) sub_wav = self.depadding( sub_wav[0], voc_chunk_num, voc_chunk_id, - self.config.voc_block, self.config.voc_pad, self.config.voc_upsample) + self.config.voc_block, self.config.voc_pad, + self.config.voc_upsample) yield self.after_process(sub_wav) voc_chunk_id += 1 - start = max( - 0, voc_chunk_id * self.config.voc_block - self.config.voc_pad) - end = min( - (voc_chunk_id + 1) * self.config.voc_block + self.config.voc_pad, - mel_len) + start = max(0, voc_chunk_id * self.config.voc_block - + self.config.voc_pad) + end = min((voc_chunk_id + 1) * self.config.voc_block + + self.config.voc_pad, mel_len) else: logging.error( "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts." - ) + ) - def streamTTSBytes(self, text): for wav in self.engine.executor.infer( text=text, @@ -191,19 +191,14 @@ class TTS: wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes yield wav_bytes - - + def after_process(self, wav): # for tvm wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64 return wav_base64 - + def streamTTS_TVM(self, text): # 用 TVM 优化 pass - - - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr.py b/demos/speech_web/speech_server/src/SpeechBase/vpr.py index 29ee986e3..46e9c0366 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/vpr.py +++ b/demos/speech_web/speech_server/src/SpeechBase/vpr.py @@ -1,11 +1,14 @@ # vpr Demo 没有使用 mysql 与 muilvs, 仅用于docker演示 import logging + import faiss -from matplotlib import use import numpy as np +from matplotlib import use + from .sql_helper import DataBase from .vpr_encode import get_audio_embedding + class VPR: def __init__(self, db_path, dim, top_k) -> None: # 初始化 @@ -14,15 +17,15 @@ class VPR: self.top_k = top_k self.dtype = np.float32 self.vpr_idx = 0 - + # db 初始化 self.db = DataBase(db_path) - + # faiss 初始化 index_ip = faiss.IndexFlatIP(dim) self.index_ip = faiss.IndexIDMap(index_ip) self.init() - + def init(self): # demo 初始化,把 mysql中的向量注册到 faiss 中 sql_dbs = self.db.select_all() @@ -34,12 +37,13 @@ class VPR: if len(vc.shape) == 1: vc = np.expand_dims(vc, axis=0) # 构建数据库 - self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64')) + self.index_ip.add_with_ids(vc, np.array( + (idx, )).astype('int64')) logging.info("faiss 构建完毕") - + def faiss_enroll(self, idx, vc): - self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64')) - + self.index_ip.add_with_ids(vc, np.array((idx, )).astype('int64')) + 
def vpr_enroll(self, username, wav_path): # 注册声纹 emb = get_audio_embedding(wav_path) @@ -53,21 +57,22 @@ class VPR: else: last_idx, mess = None return last_idx - + def vpr_recog(self, wav_path): # 识别声纹 emb_search = get_audio_embedding(wav_path) - + if emb_search is not None: emb_search = np.expand_dims(emb_search, axis=0) D, I = self.index_ip.search(emb_search, self.top_k) D = D.tolist()[0] - I = I.tolist()[0] - return [(round(D[i] * 100, 2 ), I[i]) for i in range(len(D)) if I[i] != -1] + I = I.tolist()[0] + return [(round(D[i] * 100, 2), I[i]) for i in range(len(D)) + if I[i] != -1] else: logging.error("识别失败") return None - + def do_search_vpr(self, wav_path): spk_ids, paths, scores = [], [], [] recog_result = self.vpr_recog(wav_path) @@ -78,41 +83,39 @@ class VPR: scores.append(score) paths.append("") return spk_ids, paths, scores - + def vpr_del(self, username): # 根据用户username, 删除声纹 # 查用户ID,删除对应向量 res = self.db.select_by_username(username) for r in res: idx = r['id'] - self.index_ip.remove_ids(np.array((idx,)).astype('int64')) - + self.index_ip.remove_ids(np.array((idx, )).astype('int64')) + self.db.drop_by_username(username) - + def vpr_list(self): # 获取数据列表 return self.db.select_all() - + def do_list(self): spk_ids, vpr_ids = [], [] for res in self.db.select_all(): spk_ids.append(res['username']) vpr_ids.append(res['id']) - return spk_ids, vpr_ids - + return spk_ids, vpr_ids + def do_get_wav(self, vpr_idx): - res = self.db.select_by_id(vpr_idx) - return res[0]['wavpath'] - - + res = self.db.select_by_id(vpr_idx) + return res[0]['wavpath'] + def vpr_data(self, idx): # 获取对应ID的数据 res = self.db.select_by_id(idx) return res - + def vpr_droptable(self): # 删除表 self.db.drop_table() # 清空 faiss self.index_ip.reset() - diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py index a6a00e4d0..9d052fd98 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py +++ b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py @@ -1,9 +1,12 @@ -from paddlespeech.cli.vector import VectorExecutor -import numpy as np import logging +import numpy as np + +from paddlespeech.cli.vector import VectorExecutor + vector_executor = VectorExecutor() + def get_audio_embedding(path): """ Use vpr_inference to generate embedding of audio @@ -16,5 +19,3 @@ def get_audio_embedding(path): except Exception as e: logging.error(f"Error with embedding:{e}") return None - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/WebsocketManeger.py b/demos/speech_web/speech_server/src/WebsocketManeger.py index 5edde8430..954d849a5 100644 --- a/demos/speech_web/speech_server/src/WebsocketManeger.py +++ b/demos/speech_web/speech_server/src/WebsocketManeger.py @@ -2,6 +2,7 @@ from typing import List from fastapi import WebSocket + class ConnectionManager: def __init__(self): # 存放激活的ws连接对象 @@ -28,4 +29,4 @@ class ConnectionManager: await connection.send_text(message) -manager = ConnectionManager() \ No newline at end of file +manager = ConnectionManager() diff --git a/demos/speech_web/speech_server/src/robot.py b/demos/speech_web/speech_server/src/robot.py index b971c57b5..031d91eb2 100644 --- a/demos/speech_web/speech_server/src/robot.py +++ b/demos/speech_web/speech_server/src/robot.py @@ -1,60 +1,65 @@ -from paddlespeech.cli.asr.infer import ASRExecutor -import soundfile as sf import os -import librosa +import librosa +import soundfile as sf from src.SpeechBase.asr import ASR -from src.SpeechBase.tts import TTS 
from src.SpeechBase.nlp import NLP +from src.SpeechBase.tts import TTS + + +from paddlespeech.cli.asr.infer import ASRExecutor class Robot: - def __init__(self, asr_config, tts_config,asr_init_path, + def __init__(self, + asr_config, + tts_config, + asr_init_path, ie_model_path=None) -> None: self.nlp = NLP(ie_model_path=ie_model_path) self.asr = ASR(config_path=asr_config) self.tts = TTS(config_path=tts_config) self.tts_sample_rate = 24000 self.asr_sample_rate = 16000 - + # 流式识别效果不如端到端的模型,这里流式模型与端到端模型分开 self.asr_model = ASRExecutor() self.asr_name = "conformer_wenetspeech" self.warm_up_asrmodel(asr_init_path) - - def warm_up_asrmodel(self, asr_init_path): + def warm_up_asrmodel(self, asr_init_path): if not os.path.exists(asr_init_path): path_dir = os.path.dirname(asr_init_path) if not os.path.exists(path_dir): os.makedirs(path_dir, exist_ok=True) - + # TTS生成,采样率24000 text = "生成初始音频" self.text2speech(text, asr_init_path) - + # asr model初始化 - self.asr_model(asr_init_path, model=self.asr_name,lang='zh', - sample_rate=16000, force_yes=True) - - + self.asr_model( + asr_init_path, + model=self.asr_name, + lang='zh', + sample_rate=16000, + force_yes=True) + def speech2text(self, audio_file): self.asr_model.preprocess(self.asr_name, audio_file) self.asr_model.infer(self.asr_name) res = self.asr_model.postprocess() return res - + def text2speech(self, text, outpath): wav = self.tts.offlineTTS(text) - sf.write( - outpath, wav, samplerate=self.tts_sample_rate) + sf.write(outpath, wav, samplerate=self.tts_sample_rate) res = wav return res - + def text2speechStream(self, text): for sub_wav_base64 in self.tts.streamTTS(text=text): yield sub_wav_base64 - + def text2speechStreamBytes(self, text): for wav_bytes in self.tts.streamTTSBytes(text=text): yield wav_bytes @@ -66,5 +71,3 @@ class Robot: def ie(self, text): result = self.nlp.ie(text) return result - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/util.py b/demos/speech_web/speech_server/src/util.py index 34005d919..4a566b6ee 100644 --- a/demos/speech_web/speech_server/src/util.py +++ b/demos/speech_web/speech_server/src/util.py @@ -1,18 +1,13 @@ import random + def randName(n=5): - return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba',n)) + return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', n)) + def SuccessRequest(result=None, message="ok"): - return { - "code": 0, - "result":result, - "message": message - } + return {"code": 0, "result": result, "message": message} + def ErrorRequest(result=None, message="error"): - return { - "code": -1, - "result":result, - "message": message - } \ No newline at end of file + return {"code": -1, "result": result, "message": message} diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py index 4b89b48fd..09a9c9750 100755 --- a/demos/streaming_asr_server/local/rtf_from_log.py +++ b/demos/streaming_asr_server/local/rtf_from_log.py @@ -34,7 +34,7 @@ if __name__ == '__main__': n = 0 for m in rtfs: # not accurate, may have duplicate log - n += 1 + n += 1 T += m['T'] P += m['P'] diff --git a/docs/requirements.txt b/docs/requirements.txt index bf1486c5e..a77da9f82 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,6 @@ -myst-parser -numpydoc -recommonmark>=0.5.0 -sphinx -sphinx-autobuild -sphinx-markdown-tables -sphinx_rtd_theme -paddlepaddle>=2.2.2 +braceexpand +colorlog editdistance +fastapi g2p_en g2pM h5py @@ -14,39 +8,44 @@ inflect jieba jsonlines kaldiio +keyboard librosa==0.8.1 loguru
matplotlib +myst-parser nara_wpe +numpydoc onnxruntime -pandas paddlenlp +paddlepaddle>=2.2.2 paddlespeech_feat +pandas +pathos == 0.2.8 +pattern_singleton Pillow>=9.0.0 praatio==5.0.0 +prettytable pypinyin pypinyin-dict python-dateutil pyworld==0.2.12 +recommonmark>=0.5.0 resampy==0.2.2 sacrebleu scipy sentencepiece~=0.1.96 soundfile~=0.10 +sphinx +sphinx-autobuild +sphinx-markdown-tables +sphinx_rtd_theme textgrid timer tqdm typeguard +uvicorn visualdl webrtcvad +websockets yacs~=0.1.8 -prettytable zhon -colorlog -pathos == 0.2.8 -fastapi -websockets -keyboard -uvicorn -pattern_singleton -braceexpand \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index c94cf0b86..cd9b1807b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,10 +20,11 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import os +import sys + import recommonmark.parser import sphinx_rtd_theme -import sys -import os sys.path.insert(0, os.path.abspath('../..')) autodoc_mock_imports = ["soundfile", "librosa"] diff --git a/examples/csmsc/tts1/README.md b/examples/csmsc/tts1/README.md index 0725eda2d..a5b48e6b4 100644 --- a/examples/csmsc/tts1/README.md +++ b/examples/csmsc/tts1/README.md @@ -113,9 +113,9 @@ optional arguments: --transformer-tts-stat TRANSFORMER_TTS_STAT mean and standard deviation used to normalize spectrogram when training transformer tts. - --waveflow-config WAVEFLOW_CONFIG + --voc-config WAVEFLOW_CONFIG waveflow config file. - --waveflow-checkpoint WAVEFLOW_CHECKPOINT + --voc-checkpoint WAVEFLOW_CHECKPOINT waveflow checkpoint to load. --phones-dict PHONES_DICT phone vocabulary file. @@ -150,9 +150,9 @@ optional arguments: --transformer-tts-stat TRANSFORMER_TTS_STAT mean and standard deviation used to normalize spectrogram when training transformer tts. - --waveflow-config WAVEFLOW_CONFIG + --voc-config WAVEFLOW_CONFIG waveflow config file. - --waveflow-checkpoint WAVEFLOW_CHECKPOINT + --voc-ckpt WAVEFLOW_CHECKPOINT waveflow checkpoint to load. --phones-dict PHONES_DICT phone vocabulary file. @@ -170,14 +170,14 @@ optional arguments: ## Pretrained Model Pretrained Model can be downloaded here: -- [transformer_tts_csmsc_ckpt.zip](https://pan.baidu.com/s/1jan_ZXCGKI7DHvS2jxWIEw?pwd=9i0t) +- [transformer_tts_csmsc_ckpt.zip](https://pan.baidu.com/s/1-6uvjQDxS0-6c9XZPBYqBQ?pwd=jjc3) TransformerTTS checkpoint contains files listed below. ```text transformer_tts_csmsc_ckpt ├── default.yaml # default config used to train transformer_tts ├── phone_id_map.txt # phone vocabulary file when training transformer_tts -├── snapshot_iter_1118250.pdz # model parameters and optimizer states +├── snapshot_iter_675000.pdz # model parameters and optimizer states └── speech_stats.npy # statistics used to normalize spectrogram when training transformer_tts ``` You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained transformer_tts and waveflow models. 
@@ -190,8 +190,8 @@ python3 ${BIN_DIR}/synthesize_e2e.py \ --transformer-tts-config=transformer_tts_csmsc_ckpt/default.yaml \ --transformer-tts-checkpoint=transformer_tts_csmsc_ckpt/snapshot_iter_1118250.pdz \ --transformer-tts-stat=transformer_tts_csmsc_ckpt/speech_stats.npy \ - --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ - --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ + --voc-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ + --voc-ckpt=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ --text=${BIN_DIR}/../sentences.txt \ --output-dir=exp/default/test_e2e \ --phones-dict=transformer_tts_csmsc_ckpt/phone_id_map.txt diff --git a/examples/csmsc/tts1/conf/default.yaml b/examples/csmsc/tts1/conf/default.yaml index 456b6a1e3..d4a62b836 100644 --- a/examples/csmsc/tts1/conf/default.yaml +++ b/examples/csmsc/tts1/conf/default.yaml @@ -1,17 +1,16 @@ - -fs : 22050 # Hz, sample rate -n_fft : 1024 # FFT size (samples). -win_length : 1024 # Window length (samples). 46.4ms -n_shift : 256 # Hop size (samples). 11.6ms -fmin : 0 # Hz, min frequency when converting to mel -fmax : 8000 # Hz, max frequency when converting to mel +fs : 24000 # Hz, sample rate +n_fft : 2048 # FFT size (samples). +win_length : 1200 # Window length (samples). 46.4ms +n_shift : 300 # Hop size (samples). 11.6ms +fmin : 80 # Hz, min frequency when converting to mel +fmax : 7600 # Hz, max frequency when converting to mel n_mels : 80 # mel bands window: "hann" # Window function. ########################################################### # DATA SETTING # ########################################################### -batch_size: 16 +batch_size: 4 num_workers: 2 ########################################################## @@ -82,11 +81,11 @@ optimizer: ########################################################### # TRAINING SETTING # ########################################################### -max_epoch: 500 +max_epoch: 300 num_snapshots: 5 ########################################################### # OTHER SETTING # ########################################################### -seed: 10086 \ No newline at end of file +seed: 10086 diff --git a/examples/csmsc/tts1/local/preprocess.sh b/examples/csmsc/tts1/local/preprocess.sh index f92664cef..ac5e8ec58 100644 --- a/examples/csmsc/tts1/local/preprocess.sh +++ b/examples/csmsc/tts1/local/preprocess.sh @@ -5,15 +5,27 @@ stop_stage=100 config_path=$1 +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + + if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then # extract features echo "Extract features ..." 
- python3 ${BIN_DIR}/preprocess.py \ - --dataset=ljspeech \ + python3 ${BIN_DIR}/preprocess_new.py \ + --dataset=baker\ --rootdir=~/datasets/BZNSYP/ \ --dumpdir=dump \ - --config-path=conf/default.yaml \ - --num-cpu=8 + --dur-file=durations.txt + --config-path=${config_path} \ + --num-cpu=8 \ + --cut-sil=True fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/examples/csmsc/tts1/local/synthesize.sh b/examples/csmsc/tts1/local/synthesize.sh index 9d1c47b39..8c14f647d 100644 --- a/examples/csmsc/tts1/local/synthesize.sh +++ b/examples/csmsc/tts1/local/synthesize.sh @@ -3,15 +3,93 @@ config_path=$1 train_output_path=$2 ckpt_name=$3 +stage=0 +stop_stage=0 -FLAGS_allocator_strategy=naive_best_fit \ -FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ - --transformer-tts-config=${config_path} \ - --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ - --transformer-tts-stat=dump/train/speech_stats.npy \ - --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ - --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ - --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test \ - --phones-dict=dump/phone_id_map.txt +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=pwgan_csmsc \ + --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \ + --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ + --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# for more GAN Vocoders +# multi band melgan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=mb_melgan_csmsc \ + --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\ + --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# style melgan +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=style_melgan_csmsc \ + --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \ + --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# hifigan +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "in hifigan syn" + 
FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \ + --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# wavernn +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "in wavernn syn" + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=wavernn_csmsc \ + --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \ + --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \ + --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi \ No newline at end of file diff --git a/examples/csmsc/tts1/local/synthesize_e2e.sh b/examples/csmsc/tts1/local/synthesize_e2e.sh index 25a862f90..ce2f30afd 100644 --- a/examples/csmsc/tts1/local/synthesize_e2e.sh +++ b/examples/csmsc/tts1/local/synthesize_e2e.sh @@ -4,14 +4,106 @@ config_path=$1 train_output_path=$2 ckpt_name=$3 -FLAGS_allocator_strategy=naive_best_fit \ -FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize_e2e.py \ - --transformer-tts-config=${config_path} \ - --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ - --transformer-tts-stat=dump/train/speech_stats.npy \ - --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \ - --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \ - --text=${BIN_DIR}/../sentences_en.txt \ - --output-dir=${train_output_path}/test_e2e \ - --phones-dict=dump/phone_id_map.txt +stage=0 +stop_stage=0 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=pwgan_csmsc \ + --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \ + --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ + --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \ + --lang=zh \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/test_e2e \ + --phones_dict=dump/phone_id_map.txt \ + #--inference_dir=${train_output_path}/inference + +fi + +# for more GAN Vocoders +# multi band melgan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=mb_melgan_csmsc \ + 
--voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\ + --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --lang=zh \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/test_e2e \ + --phones_dict=dump/phone_id_map.txt \ + #--inference_dir=${train_output_path}/inference +fi + +# the pretrained models haven't release now +# style melgan +# style melgan's Dygraph to Static Graph is not ready now +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=style_melgan_csmsc \ + --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \ + --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --lang=zh \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/test_e2e \ + --phones_dict=dump/phone_id_map.txt + # --inference_dir=${train_output_path}/inference +fi + +# hifigan +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "in hifigan syn_e2e" + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=hifigan_csmsc \ + --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \ + --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \ + --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \ + --lang=zh \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/test_e2e \ + --phones_dict=dump/phone_id_map.txt \ + #--inference_dir=${train_output_path}/inference +fi + +# wavernn +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "in wavernn syn_e2e" + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/../synthesize_e2e.py \ + --transformer-tts-config=${config_path} \ + --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --transformer-tts-stat=dump/train/speech_stats.npy \ + --voc=wavernn_csmsc \ + --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \ + --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \ + --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \ + --lang=zh \ + --text=${BIN_DIR}/../sentences.txt \ + --output_dir=${train_output_path}/test_e2e \ + --phones_dict=dump/phone_id_map.txt \ + #--inference_dir=${train_output_path}/inference +fi diff --git a/examples/ernie_sat/local/inference.py b/examples/ernie_sat/local/inference.py index e6a0788fd..cb0aba5c1 100644 --- a/examples/ernie_sat/local/inference.py +++ b/examples/ernie_sat/local/inference.py @@ -26,15 +26,15 @@ from align import words2phns from align import words2phns_zh from paddle import nn from sedit_arg_parser import parse_args + +from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn +from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file from utils import eval_durs from utils import get_voc_out from utils import is_chinese from utils import load_num_sequence_text from utils import read_2col_text -from 
paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn -from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file - random.seed(0) np.random.seed(0) diff --git a/examples/ernie_sat/local/inference_new.py b/examples/ernie_sat/local/inference_new.py index 525967eb1..53a46298b 100644 --- a/examples/ernie_sat/local/inference_new.py +++ b/examples/ernie_sat/local/inference_new.py @@ -27,15 +27,15 @@ from align import words2phns from align import words2phns_zh from paddle import nn from sedit_arg_parser import parse_args +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn +from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT from utils import eval_durs from utils import get_voc_out from utils import is_chinese from utils import load_num_sequence_text from utils import read_2col_text -from yacs.config import CfgNode - -from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn -from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT random.seed(0) np.random.seed(0) diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 4b1c0ef3d..b781c4a8e 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -14,5 +14,3 @@ import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) - - diff --git a/paddlespeech/audio/__init__.py b/paddlespeech/audio/__init__.py index 83be8e32e..a91958105 100644 --- a/paddlespeech/audio/__init__.py +++ b/paddlespeech/audio/__init__.py @@ -14,12 +14,12 @@ from . import compliance from . import datasets from . import features -from . import text -from . import transform -from . import streamdata from . import functional from . import io from . import metric from . import sox_effects +from . import streamdata +from . import text +from . 
import transform from .backends import load from .backends import save diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py index 753fcc11b..47a2e79b3 100644 --- a/paddlespeech/audio/streamdata/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -4,67 +4,66 @@ # Modified from https://github.com/webdataset/webdataset # # flake8: noqa - -from .cache import ( - cached_tarfile_samples, - cached_tarfile_to_samples, - lru_cleanup, - pipe_cleaner, -) -from .compat import WebDataset, WebLoader, FluidWrapper -from .extradatasets import MockDataset, with_epoch, with_length -from .filters import ( - associate, - batched, - decode, - detshuffle, - extract_keys, - getfirst, - info, - map, - map_dict, - map_tuple, - pipelinefilter, - rename, - rename_keys, - audio_resample, - select, - shuffle, - slice, - to_tuple, - transform_with, - unbatched, - xdecode, - audio_data_filter, - audio_tokenize, - audio_resample, - audio_compute_fbank, - audio_spec_aug, - sort, - audio_padding, - audio_cmvn, - placeholder, -) -from .handlers import ( - ignore_and_continue, - ignore_and_stop, - reraise_exception, - warn_and_continue, - warn_and_stop, -) +from .cache import cached_tarfile_samples +from .cache import cached_tarfile_to_samples +from .cache import lru_cleanup +from .cache import pipe_cleaner +from .compat import FluidWrapper +from .compat import WebDataset +from .compat import WebLoader +from .extradatasets import MockDataset +from .extradatasets import with_epoch +from .extradatasets import with_length +from .filters import associate +from .filters import audio_cmvn +from .filters import audio_compute_fbank +from .filters import audio_data_filter +from .filters import audio_padding +from .filters import audio_resample +from .filters import audio_spec_aug +from .filters import audio_tokenize +from .filters import batched +from .filters import decode +from .filters import detshuffle +from .filters import extract_keys +from .filters import getfirst +from .filters import info +from .filters import map +from .filters import map_dict +from .filters import map_tuple +from .filters import pipelinefilter +from .filters import placeholder +from .filters import rename +from .filters import rename_keys +from .filters import select +from .filters import shuffle +from .filters import slice +from .filters import sort +from .filters import to_tuple +from .filters import transform_with +from .filters import unbatched +from .filters import xdecode +from .handlers import ignore_and_continue +from .handlers import ignore_and_stop +from .handlers import reraise_exception +from .handlers import warn_and_continue +from .handlers import warn_and_stop +from .mix import RandomMix +from .mix import RoundRobin from .pipeline import DataPipeline -from .shardlists import ( - MultiShardSample, - ResampledShards, - SimpleShardList, - non_empty, - resampled, - shardspec, - single_node_only, - split_by_node, - split_by_worker, -) -from .tariterators import tarfile_samples, tarfile_to_samples -from .utils import PipelineStage, repeatedly -from .writer import ShardWriter, TarWriter, numpy_dumps -from .mix import RandomMix, RoundRobin +from .shardlists import MultiShardSample +from .shardlists import non_empty +from .shardlists import resampled +from .shardlists import ResampledShards +from .shardlists import shardspec +from .shardlists import SimpleShardList +from .shardlists import single_node_only +from .shardlists import split_by_node +from .shardlists import split_by_worker +from 
.tariterators import tarfile_samples +from .tariterators import tarfile_to_samples +from .utils import PipelineStage +from .utils import repeatedly +from .writer import numpy_dumps +from .writer import ShardWriter +from .writer import TarWriter diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py index ca0e2ea2f..d7f7937bd 100644 --- a/paddlespeech/audio/streamdata/autodecode.py +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -5,18 +5,19 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Automatically decode webdataset samples.""" - -import io, json, os, pickle, re, tempfile +import io +import json +import os +import pickle +import re +import tempfile from functools import partial import numpy as np - """Extensions passed on to the image decoder.""" image_extensions = "jpg jpeg png ppm pgm pbm pnm".split() - ################################################################ # handle basic datatypes ################################################################ @@ -128,7 +129,7 @@ def call_extension_handler(key, data, f, extensions): target = target.split(".") if len(target) > len(extension): continue - if extension[-len(target) :] == target: + if extension[-len(target):] == target: return f(data) return None @@ -268,7 +269,6 @@ def imagehandler(imagespec, extensions=image_extensions): ################################################################ # torch video ################################################################ - ''' def torch_video(key, data): """Decode video using the torchvideo library. @@ -289,7 +289,6 @@ def torch_video(key, data): return torchvision.io.read_video(fname, pts_unit="sec") ''' - ################################################################ # paddlespeech.audio ################################################################ @@ -359,7 +358,6 @@ def gzfilter(key, data): # decode entire training amples ################################################################ - default_pre_handlers = [gzfilter] default_post_handlers = [basichandlers] @@ -387,7 +385,8 @@ class Decoder: pre = default_pre_handlers if post is None: post = default_post_handlers - assert all(callable(h) for h in handlers), f"one of {handlers} not callable" + assert all(callable(h) + for h in handlers), f"one of {handlers} not callable" assert all(callable(h) for h in pre), f"one of {pre} not callable" assert all(callable(h) for h in post), f"one of {post} not callable" self.handlers = pre + handlers + post diff --git a/paddlespeech/audio/streamdata/cache.py b/paddlespeech/audio/streamdata/cache.py index e7bbffa1b..5cd94aa6c 100644 --- a/paddlespeech/audio/streamdata/cache.py +++ b/paddlespeech/audio/streamdata/cache.py @@ -2,7 +2,11 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset -import itertools, os, random, re, sys +import itertools +import os +import random +import re +import sys from urllib.parse import urlparse from . 
import filters @@ -40,7 +44,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False): os.remove(fname) -def download(url, dest, chunk_size=1024 ** 2, verbose=False): +def download(url, dest, chunk_size=1024**2, verbose=False): """Download a file from `url` to `dest`.""" temp = dest + f".temp{os.getpid()}" with gopen.gopen(url) as stream: @@ -65,12 +69,11 @@ def pipe_cleaner(spec): def get_file_cached( - spec, - cache_size=-1, - cache_dir=None, - url_to_name=pipe_cleaner, - verbose=False, -): + spec, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + verbose=False, ): if cache_size == -1: cache_size = default_cache_size if cache_dir is None: @@ -107,15 +110,14 @@ verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0")) def cached_url_opener( - data, - handler=reraise_exception, - cache_size=-1, - cache_dir=None, - url_to_name=pipe_cleaner, - validator=check_tar_format, - verbose=False, - always=False, -): + data, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + validator=check_tar_format, + verbose=False, + always=False, ): """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" verbose = verbose or verbose_cache for sample in data: @@ -132,8 +134,7 @@ def cached_url_opener( cache_size=cache_size, cache_dir=cache_dir, url_to_name=url_to_name, - verbose=verbose, - ) + verbose=verbose, ) if verbose: print("# opening %s" % dest, file=sys.stderr) assert os.path.exists(dest) @@ -143,9 +144,8 @@ def cached_url_opener( data = f.read(200) os.remove(dest) raise ValueError( - "%s (%s) is not a tar archive, but a %s, contains %s" - % (dest, url, ftype, repr(data)) - ) + "%s (%s) is not a tar archive, but a %s, contains %s" % + (dest, url, ftype, repr(data))) try: stream = open(dest, "rb") sample.update(stream=stream) @@ -158,7 +158,7 @@ def cached_url_opener( continue raise exn except Exception as exn: - exn.args = exn.args + (url,) + exn.args = exn.args + (url, ) if handler(exn): continue else: @@ -166,14 +166,13 @@ def cached_url_opener( def cached_tarfile_samples( - src, - handler=reraise_exception, - cache_size=-1, - cache_dir=None, - verbose=False, - url_to_name=pipe_cleaner, - always=False, -): + src, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + verbose=False, + url_to_name=pipe_cleaner, + always=False, ): streams = cached_url_opener( src, handler=handler, @@ -181,8 +180,7 @@ def cached_tarfile_samples( cache_dir=cache_dir, verbose=verbose, url_to_name=url_to_name, - always=always, - ) + always=always, ) samples = tar_file_and_group_expander(streams, handler=handler) return samples diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py index deda53384..84f46abff 100644 --- a/paddlespeech/audio/streamdata/compat.py +++ b/paddlespeech/audio/streamdata/compat.py @@ -6,13 +6,18 @@ from dataclasses import dataclass from itertools import islice from typing import List -import braceexpand, yaml +import braceexpand +import yaml from . import autodecode -from . import cache, filters, shardlists, tariterators +from . import cache +from . import filters +from . import shardlists +from . 
import tariterators from .filters import reraise_exception +from .paddle_utils import DataLoader +from .paddle_utils import IterableDataset from .pipeline import DataPipeline -from .paddle_utils import DataLoader, IterableDataset class FluidInterface: @@ -26,7 +31,8 @@ class FluidInterface: return self.compose(filters.unbatched()) def listed(self, batchsize, partial=True): - return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None) + return self.compose( + filters.batched(), batchsize=batchsize, collation_fn=None) def unlisted(self): return self.compose(filters.unlisted()) @@ -43,9 +49,19 @@ class FluidInterface: def map(self, f, handler=reraise_exception): return self.compose(filters.map(f, handler=handler)) - def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception): - handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args] - decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial) + def decode(self, + *args, + pre=None, + post=None, + only=None, + partial=False, + handler=reraise_exception): + handlers = [ + autodecode.ImageHandler(x) if isinstance(x, str) else x + for x in args + ] + decoder = autodecode.Decoder( + handlers, pre=pre, post=post, only=only, partial=partial) return self.map(decoder, handler=handler) def map_dict(self, handler=reraise_exception, **kw): @@ -80,12 +96,12 @@ class FluidInterface: def audio_data_filter(self, *args, **kw): return self.compose(filters.audio_data_filter(*args, **kw)) - + def audio_tokenize(self, *args, **kw): return self.compose(filters.audio_tokenize(*args, **kw)) def resample(self, *args, **kw): - return self.compose(filters.resample(*args, **kw)) + return self.compose(filters.resample(*args, **kw)) def audio_compute_fbank(self, *args, **kw): return self.compose(filters.audio_compute_fbank(*args, **kw)) @@ -102,27 +118,28 @@ class FluidInterface: def audio_cmvn(self, cmvn_file): return self.compose(filters.audio_cmvn(cmvn_file)) + class WebDataset(DataPipeline, FluidInterface): """Small fluid-interface wrapper for DataPipeline.""" def __init__( - self, - urls, - handler=reraise_exception, - resampled=False, - repeat=False, - shardshuffle=None, - cache_size=0, - cache_dir=None, - detshuffle=False, - nodesplitter=shardlists.single_node_only, - verbose=False, - ): + self, + urls, + handler=reraise_exception, + resampled=False, + repeat=False, + shardshuffle=None, + cache_size=0, + cache_dir=None, + detshuffle=False, + nodesplitter=shardlists.single_node_only, + verbose=False, ): super().__init__() if isinstance(urls, IterableDataset): assert not resampled self.append(urls) - elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")): + elif isinstance(urls, str) and (urls.endswith(".yaml") or + urls.endswith(".yml")): with (open(urls)) as stream: spec = yaml.safe_load(stream) assert "datasets" in spec @@ -152,9 +169,7 @@ class WebDataset(DataPipeline, FluidInterface): handler=handler, verbose=verbose, cache_size=cache_size, - cache_dir=cache_dir, - ) - ) + cache_dir=cache_dir, )) class FluidWrapper(DataPipeline, FluidInterface): diff --git a/paddlespeech/audio/streamdata/extradatasets.py b/paddlespeech/audio/streamdata/extradatasets.py index e6d617724..0933908a3 100644 --- a/paddlespeech/audio/streamdata/extradatasets.py +++ b/paddlespeech/audio/streamdata/extradatasets.py @@ -5,13 +5,10 @@ # See the LICENSE file for licensing terms (BSD-style). 
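As a point of reference for the compat-layer changes above, a minimal usage sketch of the fluid API; this is not part of the patch, the shard pattern is hypothetical, and it assumes the usual webdataset-style FluidInterface methods (shuffle, to_tuple):

from paddlespeech.audio.streamdata import WebDataset

# Hypothetical shard pattern; samples arrive as dicts with "wav"/"txt" keys.
dataset = (WebDataset("train-{000000..000007}.tar", shardshuffle=True)
           .shuffle(1000)            # sample-level shuffle buffer
           .to_tuple("wav", "txt"))  # dict -> (waveform, transcript)

for wav, txt in dataset:
    break  # first sample only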
# Modified from https://github.com/webdataset/webdataset # - - """Train PyTorch models directly from POSIX tar archive. Code works locally or over HTTP connections. """ - import itertools as itt import os import random @@ -63,8 +60,7 @@ class repeatedly(IterableDataset, PipelineStage): return utils.repeatedly( source, nepochs=self.nepochs, - nbatches=self.nbatches, - ) + nbatches=self.nbatches, ) class with_epoch(IterableDataset): diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py index 82b9c6bab..4056203fb 100644 --- a/paddlespeech/audio/streamdata/filters.py +++ b/paddlespeech/audio/streamdata/filters.py @@ -3,7 +3,6 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset # Modified from wenet(https://github.com/wenet-e2e/wenet) """A collection of iterators for data transformations. @@ -12,28 +11,32 @@ These functions are plain iterator functions. You can find curried versions in webdataset.filters, and you can find IterableDataset wrappers in webdataset.processing. """ - import io -from fnmatch import fnmatch +import itertools +import os +import random import re -import itertools, os, random, sys, time -from functools import reduce, wraps +import sys +import time +from fnmatch import fnmatch +from functools import reduce +from functools import wraps import numpy as np +import paddle from . import autodecode -from . import utils -from .paddle_utils import PaddleTensor -from .utils import PipelineStage - +from . import utils from .. import backends from ..compliance import kaldi -import paddle from ..transform.cmvn import GlobalCMVN -from ..utils.tensor_utils import pad_sequence -from ..transform.spec_augment import time_warp -from ..transform.spec_augment import time_mask from ..transform.spec_augment import freq_mask +from ..transform.spec_augment import time_mask +from ..transform.spec_augment import time_warp +from ..utils.tensor_utils import pad_sequence +from .paddle_utils import PaddleTensor +from .utils import PipelineStage + class FilterFunction(object): """Helper class for currying pipeline stages. @@ -159,10 +162,12 @@ def transform_with(sample, transformers): result[i] = f(sample[i]) return result + ### # Iterators ### + def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""): """Print information about the samples that are passing through. 
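Every stage in filters.py follows the same pattern: a plain generator over a sample stream, curried into a reusable pipeline stage by pipelinefilter. A minimal sketch of that pattern with a made-up stage name (not from this patch):

from paddlespeech.audio.streamdata.filters import pipelinefilter

def _tag(data, label="demo"):
    """Plain iterator function: annotate every sample dict in the stream."""
    for sample in data:
        sample["tag"] = label
        yield sample

tag = pipelinefilter(_tag)   # curried constructor, like the stages above
stage = tag(label="train")   # binds arguments; returns a callable stage
out = stage(iter([{"fname": "utt0001"}]))  # stage(iterator) -> iterator
print(next(out))             # {'fname': 'utt0001', 'tag': 'train'}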
@@ -325,15 +330,24 @@ def _rename(data, handler=reraise_exception, keep=True, **kw): for sample in data: try: if not keep: - yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()} + yield { + k: getfirst(sample, v, missing_is_error=True) + for k, v in kw.items() + } else: def listify(v): return v.split(";") if isinstance(v, str) else v to_be_replaced = {x for v in kw.values() for x in listify(v)} - result = {k: v for k, v in sample.items() if k not in to_be_replaced} - result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}) + result = { + k: v + for k, v in sample.items() if k not in to_be_replaced + } + result.update({ + k: getfirst(sample, v, missing_is_error=True) + for k, v in kw.items() + }) yield result except Exception as exn: if handler(exn): @@ -381,7 +395,11 @@ def _map_dict(data, handler=reraise_exception, **kw): map_dict = pipelinefilter(_map_dict) -def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None): +def _to_tuple(data, + *args, + handler=reraise_exception, + missing_is_error=True, + none_is_error=None): """Convert dict samples to tuples.""" if none_is_error is None: none_is_error = missing_is_error @@ -390,7 +408,10 @@ def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, non for sample in data: try: - result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args]) + result = tuple([ + getfirst(sample, f, missing_is_error=missing_is_error) + for f in args + ]) if none_is_error and any(x is None for x in result): raise ValueError(f"to_tuple {args} got {sample.keys()}") yield result @@ -463,19 +484,28 @@ rsample = pipelinefilter(_rsample) slice = pipelinefilter(itertools.islice) -def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False): +def _extract_keys(source, + *patterns, + duplicate_is_error=True, + ignore_missing=False): for sample in source: result = [] for pattern in patterns: - pattern = pattern.split(";") if isinstance(pattern, str) else pattern - matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)] + pattern = pattern.split(";") if isinstance(pattern, + str) else pattern + matches = [ + x for x in sample.keys() + if any(fnmatch("." 
+ x, p) for p in pattern) + ] if len(matches) == 0: if ignore_missing: continue else: - raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + raise ValueError( + f"Cannot find {pattern} in sample keys {sample.keys()}.") if len(matches) > 1 and duplicate_is_error: - raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + raise ValueError( + f"Multiple sample keys {sample.keys()} match {pattern}.") value = sample[matches[0]] result.append(value) yield tuple(result) @@ -484,7 +514,12 @@ def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=Fal extract_keys = pipelinefilter(_extract_keys) -def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw): +def _rename_keys(source, + *args, + keep_unselected=False, + must_match=True, + duplicate_is_error=True, + **kw): renamings = [(pattern, output) for output, pattern in args] renamings += [(pattern, output) for output, pattern in kw.items()] for sample in source: @@ -504,11 +539,15 @@ def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicat continue if new_name in new_sample: if duplicate_is_error: - raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + raise ValueError( + f"Duplicate value in sample {sample.keys()} after rename." + ) continue new_sample[new_name] = value if must_match and not all(matched.values()): - raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + raise ValueError( + f"Not all patterns ({matched}) matched sample keys ({sample.keys()})." + ) yield new_sample @@ -541,18 +580,18 @@ def find_decoder(decoders, path): if fname.startswith("__"): return lambda x: x for pattern, fun in decoders[::-1]: - if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern): + if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), + pattern): return fun return None def _xdecode( - source, - *args, - must_decode=True, - defaults=default_decoders, - **kw, -): + source, + *args, + must_decode=True, + defaults=default_decoders, + **kw, ): decoders = list(defaults) + list(args) decoders += [("*." + k, v) for k, v in kw.items()] for sample in source: @@ -575,18 +614,18 @@ def _xdecode( new_sample[path] = value yield new_sample -xdecode = pipelinefilter(_xdecode) +xdecode = pipelinefilter(_xdecode) def _audio_data_filter(source, - frame_shift=10, - max_length=10240, - min_length=10, - token_max_length=200, - token_min_length=1, - min_output_input_ratio=0.0005, - max_output_input_ratio=1): + frame_shift=10, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): """ Filter sample according to feature and label length Inplace operation. 
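The filter works in frames rather than seconds; the frame count in the hunk below is derived from the waveform length, the sample rate, and the frame shift. A worked example of that formula with illustrative numbers:

# 2 s of 16 kHz audio with the default frame_shift=10 ms -> 200 frames
num_samples, sample_rate, frame_shift = 32000, 16000, 10
num_frames = num_samples / sample_rate * (1000 / frame_shift)
assert num_frames == 200.0  # passes min_length=10 and max_length=10240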
@@ -613,7 +652,8 @@ def _audio_data_filter(source, assert 'wav' in sample assert 'label' in sample # sample['wav'] is paddle.Tensor, we have 100 frames every second (default) - num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift) + num_frames = sample['wav'].shape[1] / sample['sample_rate'] * ( + 1000 / frame_shift) if num_frames < min_length: continue if num_frames > max_length: @@ -629,13 +669,15 @@ def _audio_data_filter(source, continue yield sample + audio_data_filter = pipelinefilter(_audio_data_filter) + def _audio_tokenize(source, - symbol_table, - bpe_model=None, - non_lang_syms=None, - split_with_space=False): + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): """ Decode text to chars or BPE Inplace operation @@ -693,8 +735,10 @@ def _audio_tokenize(source, sample['label'] = label yield sample + audio_tokenize = pipelinefilter(_audio_tokenize) + def _audio_resample(source, resample_rate=16000): """ Resample data. Inplace operation. @@ -713,18 +757,22 @@ def _audio_resample(source, resample_rate=16000): waveform = sample['wav'] if sample_rate != resample_rate: sample['sample_rate'] = resample_rate - sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample( - waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate - )) + sample['wav'] = paddle.to_tensor( + backends.soundfile_backend.resample( + waveform.numpy(), + src_sr=sample_rate, + target_sr=resample_rate)) yield sample + audio_resample = pipelinefilter(_audio_resample) + def _audio_compute_fbank(source, - num_mel_bins=80, - frame_length=25, - frame_shift=10, - dither=0.0): + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0): """ Extract fbank Args: @@ -746,30 +794,33 @@ def _audio_compute_fbank(source, waveform = sample['wav'] waveform = waveform * (1 << 15) # Only keep fname, feat, label - mat = kaldi.fbank(waveform, - n_mels=num_mel_bins, - frame_length=frame_length, - frame_shift=frame_shift, - dither=dither, - energy_floor=0.0, - sr=sample_rate) + mat = kaldi.fbank( + waveform, + n_mels=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sr=sample_rate) yield dict(fname=sample['fname'], label=sample['label'], feat=mat) audio_compute_fbank = pipelinefilter(_audio_compute_fbank) -def _audio_spec_aug(source, - max_w=5, - w_inplace=True, - w_mode="PIL", - max_f=30, - num_f_mask=2, - f_inplace=True, - f_replace_with_zero=False, - max_t=40, - num_t_mask=2, - t_inplace=True, - t_replace_with_zero=False,): + +def _audio_spec_aug( + source, + max_w=5, + w_inplace=True, + w_mode="PIL", + max_f=30, + num_f_mask=2, + f_inplace=True, + f_replace_with_zero=False, + max_t=40, + num_t_mask=2, + t_inplace=True, + t_replace_with_zero=False, ): """ Do spec augmentation Inplace operation @@ -793,12 +844,23 @@ def _audio_spec_aug(source, for sample in source: x = sample['feat'] x = x.numpy() - x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode) - x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero) - x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero) + x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode) + x = freq_mask( + x, + F=max_f, + n_mask=num_f_mask, + inplace=f_inplace, + replace_with_zero=f_replace_with_zero) + x = time_mask( + x, + T=max_t, + n_mask=num_t_mask, + inplace=t_inplace, + replace_with_zero=t_replace_with_zero) 
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) yield sample + audio_spec_aug = pipelinefilter(_audio_spec_aug) @@ -829,8 +891,10 @@ def _sort(source, sort_size=500): for x in buf: yield x + sort = pipelinefilter(_sort) + def _batched(source, batch_size=16): """ Static batch the data by `batch_size` @@ -850,8 +914,10 @@ def _batched(source, batch_size=16): if len(buf) > 0: yield buf + batched = pipelinefilter(_batched) + def dynamic_batched(source, max_frames_in_batch=12000): """ Dynamic batch the data until the total frames in batch reach `max_frames_in_batch` @@ -892,8 +958,8 @@ def _audio_padding(source): """ for sample in source: assert isinstance(sample, list) - feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample], - dtype="int64") + feats_length = paddle.to_tensor( + [x['feat'].shape[0] for x in sample], dtype="int64") order = paddle.argsort(feats_length, descending=True) feats_lengths = paddle.to_tensor( [sample[i]['feat'].shape[0] for i in order], dtype="int64") @@ -902,20 +968,20 @@ def _audio_padding(source): sorted_labels = [ paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order ] - label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels], - dtype="int64") - padded_feats = pad_sequence(sorted_feats, - batch_first=True, - padding_value=0) - padding_labels = pad_sequence(sorted_labels, - batch_first=True, - padding_value=-1) - - yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths = paddle.to_tensor( + [x.shape[0] for x in sorted_labels], dtype="int64") + padded_feats = pad_sequence( + sorted_feats, batch_first=True, padding_value=0) + padding_labels = pad_sequence( + sorted_labels, batch_first=True, padding_value=-1) + + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) + audio_padding = pipelinefilter(_audio_padding) + def _audio_cmvn(source, cmvn_file): global_cmvn = GlobalCMVN(cmvn_file) for batch in source: @@ -923,13 +989,16 @@ def _audio_cmvn(source, cmvn_file): padded_feats = padded_feats.numpy() padded_feats = global_cmvn(padded_feats) padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32) - yield (sorted_keys, padded_feats, feats_lengths, padding_labels, - label_lengths) + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + audio_cmvn = pipelinefilter(_audio_cmvn) + def _placeholder(source): for data in source: yield data + placeholder = pipelinefilter(_placeholder) diff --git a/paddlespeech/audio/streamdata/gopen.py b/paddlespeech/audio/streamdata/gopen.py index 457d048a6..60a434603 100644 --- a/paddlespeech/audio/streamdata/gopen.py +++ b/paddlespeech/audio/streamdata/gopen.py @@ -3,12 +3,12 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). 
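For reference, the audio_padding and audio_cmvn stages above emit batches as a fixed 5-tuple; the shapes below are illustrative examples (assuming the 80-bin fbank default), not values from this patch:

# keys, feats, feat_lens, labels, label_lens = next(iter(batches))
#   keys       : list[str]      utterance ids, sorted longest-first
#   feats      : paddle.Tensor  [B, T_max, 80], fbank, zero-padded
#   feat_lens  : paddle.Tensor  [B], int64, true frame counts
#   labels     : paddle.Tensor  [B, L_max], int32, padded with -1
#   label_lens : paddle.Tensor  [B], int64, true token counts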
# - - """Open URLs by calling subcommands.""" - -import os, sys, re -from subprocess import PIPE, Popen +import os +import re +import sys +from subprocess import PIPE +from subprocess import Popen from urllib.parse import urlparse # global used for printing additional node information during verbose output @@ -31,14 +31,13 @@ class Pipe: """ def __init__( - self, - *args, - mode=None, - timeout=7200.0, - ignore_errors=False, - ignore_status=[], - **kw, - ): + self, + *args, + mode=None, + timeout=7200.0, + ignore_errors=False, + ignore_status=[], + **kw, ): """Create an IO Pipe.""" self.ignore_errors = ignore_errors self.ignore_status = [0] + ignore_status @@ -75,8 +74,7 @@ class Pipe: if verbose: print( f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}", - file=sys.stderr, - ) + file=sys.stderr, ) if self.status not in self.ignore_status and not self.ignore_errors: raise Exception(f"{self.args}: exit {self.status} (read) {info}") @@ -114,9 +112,11 @@ class Pipe: self.close() -def set_options( - obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None -): +def set_options(obj, + timeout=None, + ignore_errors=None, + ignore_status=None, + handler=None): """Set options for Pipes. This function can be called on any stream. It will set pipe options only @@ -168,16 +168,14 @@ def gopen_pipe(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141], - ) # skipcq: BAN-B604 + ignore_status=[141], ) # skipcq: BAN-B604 elif mode[0] == "w": return Pipe( cmd, mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141], - ) # skipcq: BAN-B604 + ignore_status=[141], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") @@ -196,8 +194,7 @@ def gopen_curl(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": cmd = f"curl -s -L -T - '{url}'" return Pipe( @@ -205,8 +202,7 @@ def gopen_curl(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 26], - ) # skipcq: BAN-B604 + ignore_status=[141, 26], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") @@ -226,15 +222,13 @@ def gopen_htgs(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": raise ValueError(f"{mode}: cannot write") else: raise ValueError(f"{mode}: unknown mode") - def gopen_gsutil(url, mode="rb", bufsize=8192): """Open a URL with `curl`. @@ -249,8 +243,7 @@ def gopen_gsutil(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": cmd = f"gsutil cp - '{url}'" return Pipe( @@ -258,13 +251,11 @@ def gopen_gsutil(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 26], - ) # skipcq: BAN-B604 + ignore_status=[141, 26], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") - def gopen_error(url, *args, **kw): """Raise a value error. 
@@ -285,8 +276,7 @@ gopen_schemes = dict( ftps=gopen_curl, scp=gopen_curl, gs=gopen_gsutil, - htgs=gopen_htgs, -) + htgs=gopen_htgs, ) def gopen(url, mode="rb", bufsize=8192, **kw): diff --git a/paddlespeech/audio/streamdata/handlers.py b/paddlespeech/audio/streamdata/handlers.py index 7f3d28b62..0173e5373 100644 --- a/paddlespeech/audio/streamdata/handlers.py +++ b/paddlespeech/audio/streamdata/handlers.py @@ -3,7 +3,6 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - """Pluggable exception handlers. These are functions that take an exception as an argument and then return... @@ -14,8 +13,8 @@ These are functions that take an exception as an argument and then return... They are used as handler= arguments in much of the library. """ - -import time, warnings +import time +import warnings def reraise_exception(exn): diff --git a/paddlespeech/audio/streamdata/mix.py b/paddlespeech/audio/streamdata/mix.py index 7d790f00f..8ab01a1c8 100644 --- a/paddlespeech/audio/streamdata/mix.py +++ b/paddlespeech/audio/streamdata/mix.py @@ -5,16 +5,21 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Classes for mixing samples from multiple sources.""" - -import itertools, os, random, time, sys -from functools import reduce, wraps +import itertools +import os +import random +import sys +import time +from functools import reduce +from functools import wraps import numpy as np -from . import autodecode, utils -from .paddle_utils import PaddleTensor, IterableDataset +from . import autodecode +from . import utils +from .paddle_utils import IterableDataset +from .paddle_utils import PaddleTensor from .utils import PipelineStage diff --git a/paddlespeech/audio/streamdata/paddle_utils.py b/paddlespeech/audio/streamdata/paddle_utils.py index 02bc4c841..2368ac675 100644 --- a/paddlespeech/audio/streamdata/paddle_utils.py +++ b/paddlespeech/audio/streamdata/paddle_utils.py @@ -5,10 +5,8 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Mock implementations of paddle interfaces when paddle is not available.""" - try: from paddle.io import DataLoader, IterableDataset except ModuleNotFoundError: @@ -23,6 +21,7 @@ except ModuleNotFoundError: pass + try: from paddle import Tensor as PaddleTensor except ModuleNotFoundError: diff --git a/paddlespeech/audio/streamdata/pipeline.py b/paddlespeech/audio/streamdata/pipeline.py index 7339a762a..53e5dc246 100644 --- a/paddlespeech/audio/streamdata/pipeline.py +++ b/paddlespeech/audio/streamdata/pipeline.py @@ -3,15 +3,21 @@ # See the LICENSE file for licensing terms (BSD-style). 
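pipeline.py, whose diff starts here, is what glues the stages above together. A hedged sketch of a wenet-style loader built from the reformatted exports; the shard pattern, vocabulary, and parameters are placeholders, not values from this patch:

from paddlespeech.audio import streamdata

vocab = {"<blank>": 0, "<unk>": 1}  # placeholder symbol table

pipeline = streamdata.DataPipeline(
    streamdata.SimpleShardList("shards-{000000..000009}.tar"),
    streamdata.split_by_worker,                     # shard-level split
    streamdata.tarfile_to_samples(),                # {fname, wav, sample_rate, txt}
    streamdata.audio_tokenize(symbol_table=vocab),  # txt -> label ids
    streamdata.audio_resample(resample_rate=16000),
    streamdata.audio_compute_fbank(num_mel_bins=80, frame_shift=10),
    streamdata.batched(batch_size=16),
    streamdata.audio_padding(),                     # -> 5-tuple batches
)

for keys, feats, feat_lens, labels, label_lens in pipeline:
    break  # first batch only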
# Modified from https://github.com/webdataset/webdataset #%% -import copy, os, random, sys, time +import copy +import os +import random +import sys +import time from dataclasses import dataclass from itertools import islice from typing import List -import braceexpand, yaml +import braceexpand +import yaml from .handlers import reraise_exception -from .paddle_utils import DataLoader, IterableDataset +from .paddle_utils import DataLoader +from .paddle_utils import IterableDataset from .utils import PipelineStage @@ -22,8 +28,7 @@ def add_length_method(obj): Combined = type( obj.__class__.__name__ + "_Length", (obj.__class__, IterableDataset), - {"__len__": length}, - ) + {"__len__": length}, ) obj.__class__ = Combined return obj diff --git a/paddlespeech/audio/streamdata/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py index cfaf9a64b..5b6c64351 100644 --- a/paddlespeech/audio/streamdata/shardlists.py +++ b/paddlespeech/audio/streamdata/shardlists.py @@ -4,28 +4,30 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset - """Train PyTorch models directly from POSIX tar archive. Code works locally or over HTTP connections. """ - -import os, random, sys, time -from dataclasses import dataclass, field +import os +import random +import sys +import time +from dataclasses import dataclass +from dataclasses import field from itertools import islice from typing import List -import braceexpand, yaml +import braceexpand +import yaml from . import utils +from ..utils.log import Logger from .filters import pipelinefilter from .paddle_utils import IterableDataset +logger = Logger(__name__) -from ..utils.log import Logger -logger = Logger(__name__) def expand_urls(urls): if isinstance(urls, str): urllist = urls.split("::") @@ -64,7 +66,8 @@ class SimpleShardList(IterableDataset): def split_by_node(src, group=None): - rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + rank, world_size, worker, num_workers = utils.paddle_worker_info( + group=group) logger.info(f"world_size:{world_size}, rank:{rank}") if world_size > 1: for s in islice(src, rank, None, world_size): @@ -75,9 +78,11 @@ def split_by_node(src, group=None): def single_node_only(src, group=None): - rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + rank, world_size, worker, num_workers = utils.paddle_worker_info( + group=group) if world_size > 1: - raise ValueError("input pipeline needs to be reconfigured for multinode training") + raise ValueError( + "input pipeline needs to be reconfigured for multinode training") for s in src: yield s @@ -104,7 +109,8 @@ def resampled_(src, n=sys.maxsize): rng = random.Random(seed) print("# resampled loading", file=sys.stderr) items = list(src) - print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) + print( + f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) for i in range(n): yield rng.choice(items) @@ -118,7 +124,9 @@ def non_empty(src): yield s count += 1 if count == 0: - raise ValueError("pipeline stage received no data at all and this was declared as an error") + raise ValueError( + "pipeline stage received no data at all and this was declared as an error" + ) @dataclass @@ -142,6 +150,8 @@ class MultiShardSample(IterableDataset): def __init__(self, fname): """Construct a shardlist from multiple sources using a YAML spec.""" self.epoch = -1 + + class 
MultiShardSample(IterableDataset): def __init__(self, fname): """Construct a shardlist from multiple sources using a YAML spec.""" @@ -156,20 +166,23 @@ class MultiShardSample(IterableDataset): else: with open(fname) as stream: spec = yaml.safe_load(stream) - assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys()) + assert set(spec.keys()).issubset( + set("prefix datasets buckets".split())), list(spec.keys()) prefix = expand(spec.get("prefix", "")) self.sources = [] for ds in spec["datasets"]: - assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list( - ds.keys() - ) + assert set(ds.keys()).issubset( + set("buckets name shards resample choose".split())), list( + ds.keys()) buckets = ds.get("buckets", spec.get("buckets", [])) if isinstance(buckets, str): buckets = [buckets] buckets = [expand(s) for s in buckets] if buckets == []: buckets = [""] - assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" + assert len( + buckets + ) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" bucket = buckets[0] name = ds.get("name", "@" + bucket) urls = ds["shards"] @@ -177,15 +190,19 @@ class MultiShardSample(IterableDataset): urls = [urls] # urls = [u for url in urls for u in braceexpand.braceexpand(url)] urls = [ - prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url)) + prefix + os.path.join(bucket, u) + for url in urls for u in braceexpand.braceexpand(expand(url)) ] resample = ds.get("resample", -1) nsample = ds.get("choose", -1) if nsample > len(urls): - raise ValueError(f"perepoch {nsample} must be no greater than the number of shards") + raise ValueError( + f"perepoch {nsample} must be no greater than the number of shards" + ) if (nsample > 0) and (resample > 0): raise ValueError("specify only one of perepoch or choose") - entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample) + entry = MSSource( + name=name, urls=urls, perepoch=nsample, resample=resample) self.sources.append(entry) print(f"# {name} {len(urls)} {nsample}", file=sys.stderr) @@ -203,7 +220,7 @@ class MultiShardSample(IterableDataset): # sample without replacement l = list(source.urls) self.rng.shuffle(l) - l = l[: source.perepoch] + l = l[:source.perepoch] else: l = list(source.urls) result += l @@ -227,12 +244,11 @@ class ResampledShards(IterableDataset): """An iterable dataset yielding a list of urls.""" def __init__( - self, - urls, - nshards=sys.maxsize, - worker_seed=None, - deterministic=False, - ): + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, ): """Sample shards from the shard list with replacement. 
:param urls: a list of URLs as a Python list or brace notation string @@ -252,7 +268,8 @@ class ResampledShards(IterableDataset): if self.deterministic: seed = utils.make_seed(self.worker_seed(), self.epoch) else: - seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4)) + seed = utils.make_seed(self.worker_seed(), self.epoch, + os.getpid(), time.time_ns(), os.urandom(4)) if os.environ.get("WDS_SHOW_SEED", "0") == "1": print(f"# ResampledShards seed {seed}") self.rng = random.Random(seed) diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py index b1616918c..79b81c0ce 100644 --- a/paddlespeech/audio/streamdata/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -3,13 +3,12 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). - # Modified from https://github.com/webdataset/webdataset # Modified from wenet(https://github.com/wenet-e2e/wenet) - """Low level iteration functions for tar archives.""" - -import random, re, tarfile +import random +import re +import tarfile import braceexpand @@ -27,6 +26,7 @@ import numpy as np AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + def base_plus_ext(path): """Split off all file extensions. @@ -47,12 +47,8 @@ def valid_sample(sample): :param sample: sample to be checked """ - return ( - sample is not None - and isinstance(sample, dict) - and len(list(sample.keys())) > 0 - and not sample.get("__bad__", False) - ) + return (sample is not None and isinstance(sample, dict) and + len(list(sample.keys())) > 0 and not sample.get("__bad__", False)) # FIXME: UNUSED @@ -79,16 +75,16 @@ def url_opener(data, handler=reraise_exception, **kw): sample.update(stream=stream) yield sample except Exception as exn: - exn.args = exn.args + (url,) + exn.args = exn.args + (url, ) if handler(exn): continue else: break -def tar_file_iterator( - fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception -): +def tar_file_iterator(fileobj, + skip_meta=r"__[^/]*__($|/)", + handler=reraise_exception): """Iterate over tar file, yielding filename, content pairs for the given tar stream. 
:param fileobj: byte stream suitable for tarfile @@ -103,11 +99,8 @@ def tar_file_iterator( continue if fname is None: continue - if ( - "/" not in fname - and fname.startswith(meta_prefix) - and fname.endswith(meta_suffix) - ): + if ("/" not in fname and fname.startswith(meta_prefix) and + fname.endswith(meta_suffix)): # skipping metadata for now continue if skip_meta is not None and re.match(skip_meta, fname): @@ -118,8 +111,10 @@ def tar_file_iterator( assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if postfix == 'wav': - waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False) - result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) + waveform, sample_rate = paddlespeech.audio.load( + stream.extractfile(tarinfo), normal=False) + result = dict( + fname=prefix, wav=waveform, sample_rate=sample_rate) else: txt = stream.extractfile(tarinfo).read().decode('utf8').strip() result = dict(fname=prefix, txt=txt) @@ -128,16 +123,17 @@ def tar_file_iterator( stream.members = [] except Exception as exn: if hasattr(exn, "args") and len(exn.args) > 0: - exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + exn.args = (exn.args[0] + " @ " + str(fileobj), ) + exn.args[1:] if handler(exn): continue else: break del stream -def tar_file_and_group_iterator( - fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception -): + +def tar_file_and_group_iterator(fileobj, + skip_meta=r"__[^/]*__($|/)", + handler=reraise_exception): """ Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix @@ -167,8 +163,11 @@ def tar_file_and_group_iterator( if postfix == 'txt': example['txt'] = file_obj.read().decode('utf8').strip() elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False) - waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) + waveform, sample_rate = paddlespeech.audio.load( + file_obj, normal=False) + waveform = paddle.to_tensor( + np.expand_dims(np.array(waveform), 0), + dtype=paddle.float32) example['wav'] = waveform example['sample_rate'] = sample_rate @@ -176,19 +175,21 @@ def tar_file_and_group_iterator( example[postfix] = file_obj.read() except Exception as exn: if hasattr(exn, "args") and len(exn.args) > 0: - exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + exn.args = (exn.args[0] + " @ " + str(fileobj), + ) + exn.args[1:] if handler(exn): continue else: break valid = False - # logging.warning('error to parse {}'.format(name)) + # logging.warning('error to parse {}'.format(name)) prev_prefix = prefix if prev_prefix is not None: example['fname'] = prev_prefix yield example stream.close() + def tar_file_expander(data, handler=reraise_exception): """Expand a stream of open tar files into a stream of tar file contents. @@ -200,9 +201,8 @@ def tar_file_expander(data, handler=reraise_exception): assert isinstance(source, dict) assert "stream" in source for sample in tar_file_iterator(source["stream"]): - assert ( - isinstance(sample, dict) and "data" in sample and "fname" in sample - ) + assert (isinstance(sample, dict) and "data" in sample and + "fname" in sample) sample["__url__"] = url yield sample except Exception as exn: @@ -213,8 +213,6 @@ def tar_file_expander(data, handler=reraise_exception): break - - def tar_file_and_group_expander(data, handler=reraise_exception): """Expand a stream of open tar files into a stream of tar file contents. 
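The grouping convention implemented by the iterator above, summarized: tar members that share a prefix are folded into one sample dict (file names here are hypothetical):

# utt0001.wav + utt0001.txt inside the tar become a single sample:
# {
#     "fname": "utt0001",          # shared prefix
#     "wav": waveform,             # paddle.float32 tensor, shape [1, num_samples]
#     "sample_rate": sample_rate,  # as reported by paddlespeech.audio.load
#     "txt": "its transcript",     # utf-8 decoded and stripped
# }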
@@ -226,9 +224,8 @@ def tar_file_and_group_expander(data, handler=reraise_exception): assert isinstance(source, dict) assert "stream" in source for sample in tar_file_and_group_iterator(source["stream"]): - assert ( - isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample - ) + assert (isinstance(sample, dict) and "wav" in sample and + "txt" in sample and "fname" in sample) sample["__url__"] = url yield sample except Exception as exn: @@ -239,7 +236,11 @@ def tar_file_and_group_expander(data, handler=reraise_exception): break -def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): +def group_by_keys(data, + keys=base_plus_ext, + lcase=True, + suffixes=None, + handler=None): """Return function over iterator that groups key, value pairs into samples. :param keys: function that splits the key into key and extension (base_plus_ext) @@ -254,8 +255,8 @@ def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=N print( prefix, suffix, - current_sample.keys() if isinstance(current_sample, dict) else None, - ) + current_sample.keys() + if isinstance(current_sample, dict) else None, ) if prefix is None: continue if lcase: diff --git a/paddlespeech/audio/streamdata/utils.py b/paddlespeech/audio/streamdata/utils.py index c7294f2bf..74eea6a03 100644 --- a/paddlespeech/audio/streamdata/utils.py +++ b/paddlespeech/audio/streamdata/utils.py @@ -4,22 +4,24 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset - """Miscellaneous utility functions.""" - import importlib import itertools as itt import os import re import sys -from typing import Any, Callable, Iterator, Optional, Union +from typing import Any +from typing import Callable +from typing import Iterator +from typing import Optional +from typing import Union from ..utils.log import Logger logger = Logger(__name__) + def make_seed(*args): seed = 0 for arg in args: @@ -37,7 +39,7 @@ def identity(x: Any) -> Any: return x -def safe_eval(s: str, expr: str = "{}"): +def safe_eval(s: str, expr: str="{}"): """Evaluate the given expression more safely.""" if re.sub("[^A-Za-z0-9_]", "", s) != s: raise ValueError(f"safe_eval: illegal characters in: '{s}'") @@ -54,9 +56,9 @@ def lookup_sym(sym: str, modules: list): return None -def repeatedly0( - loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize -): +def repeatedly0(loader: Iterator, + nepochs: int=sys.maxsize, + nbatches: int=sys.maxsize): """Repeatedly returns batches from a DataLoader.""" for epoch in range(nepochs): for sample in itt.islice(loader, nbatches): @@ -69,12 +71,11 @@ def guess_batchsize(batch: Union[tuple, list]): def repeatedly( - source: Iterator, - nepochs: int = None, - nbatches: int = None, - nsamples: int = None, - batchsize: Callable[..., int] = guess_batchsize, -): + source: Iterator, + nepochs: int=None, + nbatches: int=None, + nsamples: int=None, + batchsize: Callable[..., int]=guess_batchsize, ): """Repeatedly yield samples from an iterator.""" epoch = 0 batch = 0 @@ -93,6 +94,7 @@ def repeatedly( if nepochs is not None and epoch >= nepochs: return + def paddle_worker_info(group=None): """Return node and worker info for PyTorch and some distributed environments.""" rank = 0 @@ -126,6 +128,7 @@ def paddle_worker_info(group=None): return rank, world_size, worker, num_workers + def paddle_worker_seed(group=None): """Compute a distinct, deterministic RNG seed for each 
worker and node.""" rank, world_size, worker, num_workers = paddle_worker_info(group=group) diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py index 7d4f7703b..3928a3ba6 100644 --- a/paddlespeech/audio/streamdata/writer.py +++ b/paddlespeech/audio/streamdata/writer.py @@ -5,18 +5,24 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Classes and functions for writing tar files and WebDataset files.""" - -import io, json, pickle, re, tarfile, time -from typing import Any, Callable, Optional, Union +import io +import json +import pickle +import re +import tarfile +import time +from typing import Any +from typing import Callable +from typing import Optional +from typing import Union import numpy as np from . import gopen -def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622 +def imageencoder(image: Any, format: str="PNG"): # skipcq: PYL-W0622 """Compress an image using PIL and return it as a string. Can handle float or uint8 images. @@ -67,6 +73,7 @@ def bytestr(data: Any): return data.encode("ascii") return str(data).encode("ascii") + def paddle_dumps(data: Any): """Dump data into a bytestring using paddle.dumps. @@ -82,6 +89,7 @@ def paddle_dumps(data: Any): paddle.save(data, stream) return stream.getvalue() + def numpy_dumps(data: np.ndarray): """Dump data into a bytestring using numpy npy format. @@ -139,9 +147,8 @@ def add_handlers(d, keys, value): def make_handlers(): """Create a list of handlers for encoding data.""" handlers = {} - add_handlers( - handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii") - ) + add_handlers(handlers, "cls cls2 class count index inx id", + lambda x: str(x).encode("ascii")) add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) add_handlers(handlers, "pyd pickle", pickle.dumps) @@ -152,7 +159,8 @@ def make_handlers(): add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) add_handlers(handlers, "mp msgpack msg", mp_dumps) add_handlers(handlers, "cbor", cbor_dumps) - add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg")) + add_handlers(handlers, "jpg jpeg img image", + lambda data: imageencoder(data, "jpg")) add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) @@ -192,7 +200,8 @@ def encode_based_on_extension(sample: dict, handlers: dict): :param handlers: handlers for encoding """ return { - k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items()) + k: encode_based_on_extension1(v, k, handlers) + for k, v in list(sample.items()) } @@ -258,15 +267,14 @@ class TarWriter: """ def __init__( - self, - fileobj, - user: str = "bigdata", - group: str = "bigdata", - mode: int = 0o0444, - compress: Optional[bool] = None, - encoder: Union[None, bool, Callable] = True, - keep_meta: bool = False, - ): + self, + fileobj, + user: str="bigdata", + group: str="bigdata", + mode: int=0o0444, + compress: Optional[bool]=None, + encoder: Union[None, bool, Callable]=True, + keep_meta: bool=False, ): """Create a tar writer. 
:param fileobj: stream to write data to @@ -330,8 +338,7 @@ class TarWriter: continue if not isinstance(v, (bytes, bytearray, memoryview)): raise ValueError( - f"{k} doesn't map to a bytes after encoding ({type(v)})" - ) + f"{k} doesn't map to a bytes after encoding ({type(v)})") key = obj["__key__"] for k in sorted(obj.keys()): if k == "__key__": @@ -349,7 +356,8 @@ class TarWriter: ti.uname = self.user ti.gname = self.group if not isinstance(v, (bytes, bytearray, memoryview)): - raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}") + raise ValueError( + f"converter didn't yield bytes: {k}, {type(v)}") stream = io.BytesIO(v) self.tarstream.addfile(ti, stream) total += ti.size @@ -360,14 +368,13 @@ class ShardWriter: """Like TarWriter but splits into multiple shards.""" def __init__( - self, - pattern: str, - maxcount: int = 100000, - maxsize: float = 3e9, - post: Optional[Callable] = None, - start_shard: int = 0, - **kw, - ): + self, + pattern: str, + maxcount: int=100000, + maxsize: float=3e9, + post: Optional[Callable]=None, + start_shard: int=0, + **kw, ): """Create a ShardWriter. :param pattern: output file pattern @@ -400,8 +407,7 @@ class ShardWriter: self.fname, self.count, "%.1f GB" % (self.size / 1e9), - self.total, - ) + self.total, ) self.shard += 1 stream = open(self.fname, "wb") self.tarstream = TarWriter(stream, **self.kw) @@ -413,11 +419,8 @@ class ShardWriter: :param obj: sample to be written """ - if ( - self.tarstream is None - or self.count >= self.maxcount - or self.size >= self.maxsize - ): + if (self.tarstream is None or self.count >= self.maxcount or + self.size >= self.maxsize): self.next_stream() size = self.tarstream.write(obj) self.count += 1 diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py index 91c4d75c3..bcd6df54b 100644 --- a/paddlespeech/audio/text/text_featurizer.py +++ b/paddlespeech/audio/text/text_featurizer.py @@ -17,6 +17,7 @@ from typing import Union import sentencepiece as spm +from ..utils.log import Logger from .utility import BLANK from .utility import EOS from .utility import load_dict @@ -24,7 +25,6 @@ from .utility import MASKCTC from .utility import SOS from .utility import SPACE from .utility import UNK -from ..utils.log import Logger logger = Logger(__name__) diff --git a/paddlespeech/audio/transform/perturb.py b/paddlespeech/audio/transform/perturb.py index 8044dc36f..0825caec8 100644 --- a/paddlespeech/audio/transform/perturb.py +++ b/paddlespeech/audio/transform/perturb.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
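To round out the writer.py changes above, a hedged sketch of producing shards with ShardWriter; the pattern and payload are illustrative, and values that are already bytes pass through the default encoders unchanged:

from paddlespeech.audio.streamdata.writer import ShardWriter

with ShardWriter("shard-%06d.tar", maxcount=1000) as sink:  # new shard every 1000 samples
    with open("utt0001.wav", "rb") as f:  # hypothetical input file
        sink.write({
            "__key__": "utt0001",
            "wav": f.read(),        # bytes are written as-is
            "txt": "a transcript",  # utf-8 encoded by the default handler
        })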
# Modified from espnet(https://github.com/espnet/espnet) +import io +import os + +import h5py import librosa import numpy +import numpy as np import scipy import soundfile -import io -import os -import h5py -import numpy as np class SoundHDF5File(): """Collecting sound files to a HDF5 file @@ -109,6 +110,7 @@ class SoundHDF5File(): def close(self): self.file.close() + class SpeedPerturbation(): """SpeedPerturbation @@ -558,4 +560,3 @@ class RIRConvolve(): [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) else: return scipy.convolve(x, rir, mode="same") - diff --git a/paddlespeech/audio/transform/spec_augment.py b/paddlespeech/audio/transform/spec_augment.py index 029e7b8f5..b2635066f 100644 --- a/paddlespeech/audio/transform/spec_augment.py +++ b/paddlespeech/audio/transform/spec_augment.py @@ -14,6 +14,7 @@ # Modified from espnet(https://github.com/espnet/espnet) """Spec Augment module for preprocessing i.e., data augmentation""" import random + import numpy from PIL import Image diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 3800c36db..b53eed88c 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -191,7 +191,7 @@ class BaseExecutor(ABC): line = line.strip() if not line: continue - k, v = line.split() # space or \t + k, v = line.split() # space or \t job_contents[k] = v return job_contents diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index f6476b9aa..5fe2e16b9 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'): paddle.Tensor.new_full = new_full paddle.static.Variable.new_full = new_full + def contiguous(xs: paddle.Tensor) -> paddle.Tensor: return xs diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index cdad3b8f7..4978aab75 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -26,8 +26,8 @@ from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import BatchDataLoader -from paddlespeech.s2t.io.dataloader import StreamDataLoader from paddlespeech.s2t.io.dataloader import DataLoaderFactory +from paddlespeech.s2t.io.dataloader import StreamDataLoader from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory from paddlespeech.s2t.training.reporter import ObsScope @@ -109,7 +109,8 @@ class U2Trainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -136,7 +137,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -157,7 +159,8 @@ class U2Trainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -225,14 
+228,18 @@ class U2Trainer(Trainer): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) if self.train: - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) - self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) logger.info("Setup test/align Dataloader!") def setup_model(self): diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index cb015c116..073d74293 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -105,7 +105,8 @@ class U2Trainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -133,7 +134,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -153,7 +155,8 @@ class U2Trainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -165,8 +168,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) + msg += "batch : {}/{}, ".format( + batch_index + 1, len(self.train_loader)) msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "data time: {:>.3f}s, ".format(dataload_time) self.train_batch(batch_index, batch, msg) @@ -204,21 +207,24 @@ class U2Trainer(Trainer): self.use_streamdata = config.get("use_stream_data", False) if self.train: config = self.config.clone() - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) config = self.config.clone() config['preprocess_config'] = None - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: config = self.config.clone() config['preprocess_config'] = None - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, 
+ self.args) config = self.config.clone() config['preprocess_config'] = None - self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) logger.info("Setup test/align Dataloader!") - def setup_model(self): config = self.config diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 603825435..d57c49546 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -121,7 +121,8 @@ class U2STTrainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -155,7 +156,8 @@ class U2STTrainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -175,7 +177,8 @@ class U2STTrainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -248,14 +251,16 @@ class U2STTrainer(Trainer): config['load_transcript'] = load_transcript self.use_streamdata = config.get("use_stream_data", False) if self.train: - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) logger.info("Setup test Dataloader!") - def setup_model(self): config = self.config model_conf = config diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 831830241..3aff5f59b 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -22,17 +22,16 @@ import paddle from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler +from yacs.config import CfgNode +import paddlespeech.audio.streamdata as streamdata +from paddlespeech.audio.text.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.batchfy import make_batchset from paddlespeech.s2t.io.converter import CustomConverter from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log -import paddlespeech.audio.streamdata as streamdata -from paddlespeech.audio.text.text_featurizer import TextFeaturizer -from yacs.config import CfgNode - __all__ = ["BatchDataLoader", "StreamDataLoader"] logger = Log(__name__).getlog() @@ -61,6 +60,7 @@ def batch_collate(x): """ return x[0] + def read_preprocess_cfg(preprocess_conf_file): augment_conf = 
dict() preprocess_cfg = CfgNode(new_allowed=True) @@ -82,7 +82,8 @@ def read_preprocess_cfg(preprocess_conf_file): augment_conf['num_t_mask'] = process['n_mask'] augment_conf['t_inplace'] = process['inplace'] augment_conf['t_replace_with_zero'] = process['replace_with_zero'] - return augment_conf + return augment_conf + class StreamDataLoader(): def __init__(self, @@ -95,12 +96,12 @@ class StreamDataLoader(): frame_length=25, frame_shift=10, dither=0.0, - minlen_in: float=0.0, + minlen_in: float=0.0, maxlen_in: float=float('inf'), minlen_out: float=0.0, maxlen_out: float=float('inf'), resample_rate: int=16000, - shuffle_size: int=10000, + shuffle_size: int=10000, sort_size: int=1000, n_iter_processes: int=1, prefetch_factor: int=2, @@ -116,11 +117,11 @@ class StreamDataLoader(): text_featurizer = TextFeaturizer(unit_type, vocab_filepath) symbol_table = text_featurizer.vocab_dict - self.feat_dim = num_mel_bins - self.vocab_size = text_featurizer.vocab_size - + self.feat_dim = num_mel_bins + self.vocab_size = text_featurizer.vocab_size + augment_conf = read_preprocess_cfg(preprocess_conf) - + # The list of shards shardlist = [] with open(manifest_file, "r") as f: @@ -128,58 +129,68 @@ shardlist.append(line.strip()) world_size = 1 try: - world_size = paddle.distributed.get_world_size() + world_size = paddle.distributed.get_world_size() except Exception as e: logger.warning(e) - logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1") - assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...") + logger.warning( + "cannot get world_size using paddle.distributed.get_world_size(), using world_size=1" + ) + assert len(shardlist) >= world_size, ( + "the length of the shard list should be >= the number of gpus/xpus/...")
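Note: the rewritten assert above keeps its condition bare and parenthesizes only the message. The superficially similar `assert(cond, msg)` spelling builds a two-element tuple, and a non-empty tuple is always truthy, so such a check can never fail. A two-line demonstration:

assert (False, "never raises")  # tuple assert: always passes, silently
assert 2 + 2 == 4, "raises AssertionError only when the condition is False"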
- update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0)) + update_n_iter_processes = int( + max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0)) logger.info(f"update_n_iter_processes {update_n_iter_processes}") if update_n_iter_processes != self.n_iter_processes: - self.n_iter_processes = update_n_iter_processes + self.n_iter_processes = update_n_iter_processes logger.info(f"change num_workers to {self.n_iter_processes}") if self.dist_sampler: base_dataset = streamdata.DataPipeline( - streamdata.SimpleShardList(shardlist), - streamdata.split_by_node if train_mode else streamdata.placeholder(), + streamdata.SimpleShardList(shardlist), streamdata.split_by_node + if train_mode else streamdata.placeholder(), streamdata.split_by_worker, - streamdata.tarfile_to_samples(streamdata.reraise_exception) - ) + streamdata.tarfile_to_samples(streamdata.reraise_exception)) else: base_dataset = streamdata.DataPipeline( streamdata.SimpleShardList(shardlist), streamdata.split_by_worker, - streamdata.tarfile_to_samples(streamdata.reraise_exception) - ) + streamdata.tarfile_to_samples(streamdata.reraise_exception)) self.dataset = base_dataset.append_list( streamdata.audio_tokenize(symbol_table), - streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), + streamdata.audio_data_filter( + frame_shift=frame_shift, + max_length=maxlen_in, + min_length=minlen_in, + token_max_length=maxlen_out, + token_min_length=minlen_out), streamdata.audio_resample(resample_rate=resample_rate), - streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.audio_compute_fbank( + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither), + streamdata.audio_spec_aug(**augment_conf) + if train_mode else streamdata.placeholder( + ), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) streamdata.shuffle(shuffle_size), streamdata.sort(sort_size=sort_size), streamdata.batched(batch_size), streamdata.audio_padding(), - streamdata.audio_cmvn(cmvn_file) - ) + streamdata.audio_cmvn(cmvn_file)) if paddle.__version__ >= '2.3.2': self.loader = streamdata.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - prefetch_factor = self.prefetch_factor, - batch_size=None - ) + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor=self.prefetch_factor, + batch_size=None) else: self.loader = streamdata.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - batch_size=None - ) + self.dataset, + num_workers=self.n_iter_processes, + batch_size=None) def __iter__(self): return self.loader.__iter__() @@ -188,7 +199,9 @@ class StreamDataLoader(): return self.__iter__() def __len__(self): - logger.info("Stream dataloader does not support calculate the length of the dataset") + logger.info( + "Stream dataloader does not support calculating the length of the dataset" + ) return -1 @@ -347,7 +360,7 @@ class DataLoaderFactory(): config['train_mode'] = True elif mode == 'valid': config['manifest'] = config.dev_manifest - config['train_mode'] = False + config['train_mode'] = False elif mode == 'test' or mode == 'align': config['manifest'] = config.test_manifest config['train_mode'] = False @@ -358,30 +371,31 @@ config['maxlen_out'] = float('inf') config['dist_sampler'] = False else: - raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") - return StreamDataLoader( - manifest_file=config.manifest, - train_mode=config.train_mode, - unit_type=config.unit_type, - preprocess_conf=config.preprocess_config, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.dist_sampler, - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, + raise KeyError( + "invalid mode type, please input one of 'train', 'valid', 'test', 'align'" ) + return StreamDataLoader( + manifest_file=config.manifest, + train_mode=config.train_mode, + unit_type=config.unit_type, + preprocess_conf=config.preprocess_config, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.dist_sampler,
+ cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, ) else: if mode == 'train': config['manifest'] = config.train_manifest @@ -410,7 +424,7 @@ class DataLoaderFactory(): config['train_mode'] = False config['sortagrad'] = False config['batch_size'] = config.get('decode', dict()).get( - 'decode_batch_size', 1) + 'decode_batch_size', 1) config['maxlen_in'] = float('inf') config['maxlen_out'] = float('inf') config['minibatches'] = 0 @@ -426,8 +440,10 @@ config['dist_sampler'] = False config['shortest_first'] = False else: - raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") - + raise KeyError( + "invalid mode type, please input one of 'train', 'valid', 'test', 'align'" + ) + return BatchDataLoader( json_file=config.manifest, train_mode=config.train_mode, @@ -449,4 +465,3 @@ num_encs=config.num_encs, dist_sampler=config.dist_sampler, shortest_first=config.shortest_first) - diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py index ac55af123..89752bb9f 100644 --- a/paddlespeech/s2t/io/sampler.py +++ b/paddlespeech/s2t/io/sampler.py @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): """ rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) - batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) + batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) batch_indices = [item for batch in batch_indices for item in batch] assert clipped is False diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index e86bbedfa..8a811a52b 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -26,6 +26,8 @@ import paddle from paddle import jit from paddle import nn +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.modules.cmvn import GlobalCMVN @@ -38,8 +40,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.log import Log -from paddlespeech.audio.utils.tensor_utils import add_sos_eos -from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] @@ -401,8 +401,8 @@ class U2STBaseModel(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. @@ -435,8 +435,8 @@ paddle.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache.
""" - return self.encoder.forward_chunk( - xs, offset, required_cache_size, att_cache, cnn_cache) + return self.encoder.forward_chunk(xs, offset, required_cache_size, + att_cache, cnn_cache) # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py index cacda2461..34d796145 100644 --- a/paddlespeech/s2t/modules/align.py +++ b/paddlespeech/s2t/modules/align.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math + import paddle from paddle import nn -import math """ To align the initializer between paddle and torch, the API below are set defalut initializer with priority higger than global initializer. @@ -81,10 +82,18 @@ class Linear(nn.Linear): name=None): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Linear, self).__init__(in_features, out_features, weight_attr, bias_attr, name) @@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D): data_format='NCL'): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Conv1D, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) @@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D): data_format='NCHW'): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Conv2D, self).__init__( in_channels, out_channels, kernel_size, 
stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py index b6d615867..ee6dd7fa3 100644 --- a/paddlespeech/s2t/modules/attention.py +++ b/paddlespeech/s2t/modules/attention.py @@ -83,11 +83,12 @@ class MultiHeadedAttention(nn.Layer): return q, k, v - def forward_attention(self, - value: paddle.Tensor, + def forward_attention( + self, + value: paddle.Tensor, scores: paddle.Tensor, - mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool), - ) -> paddle.Tensor: + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + ) -> paddle.Tensor: """Compute attention context vector. Args: value (paddle.Tensor): Transformed value, size @@ -108,7 +109,7 @@ class MultiHeadedAttention(nn.Layer): # When will `if mask.size(2) > 0` be False? # 1. onnx(16/-1, -1/-1, 16/0) # 2. jit (16/-1, -1/-1, 16/0, 16/4) - if paddle.shape(mask)[2] > 0: # time2 > 0 + if paddle.shape(mask)[2] > 0: # time2 > 0 mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2) # for last chunk, time2 might be larger than scores.size(-1) mask = mask[:, :, :, :paddle.shape(scores)[-1]] @@ -131,9 +132,9 @@ class MultiHeadedAttention(nn.Layer): query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), - pos_emb: paddle.Tensor = paddle.empty([0]), - cache: paddle.Tensor = paddle.zeros([0,0,0,0]) + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + pos_emb: paddle.Tensor=paddle.empty([0]), + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute scaled dot product attention. Args: @@ -247,9 +248,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, - mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), - pos_emb: paddle.Tensor = paddle.empty([0]), - cache: paddle.Tensor = paddle.zeros([0,0,0,0]) + mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + pos_emb: paddle.Tensor=paddle.empty([0]), + cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]) ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute 'Scaled Dot Product Attention' with rel. positional encoding. Args: diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py index c384b9c78..837ad470d 100644 --- a/paddlespeech/s2t/modules/conformer_convolution.py +++ b/paddlespeech/s2t/modules/conformer_convolution.py @@ -106,11 +106,12 @@ class ConvolutionModule(nn.Layer): ) self.activation = activation - def forward(self, - x: paddle.Tensor, - mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool), - cache: paddle.Tensor= paddle.zeros([0,0,0]), - ) -> Tuple[paddle.Tensor, paddle.Tensor]: + def forward( + self, + x: paddle.Tensor, + mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), + cache: paddle.Tensor=paddle.zeros([0, 0, 0]), + ) -> Tuple[paddle.Tensor, paddle.Tensor]: """Compute convolution module. Args: x (paddle.Tensor): Input tensor (#batch, time, channels). 
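Note: the signature rewrites in attention.py and conformer_convolution.py keep the empty-tensor defaults (paddle.ones([0, 0, 0], dtype=paddle.bool) standing for "no mask", paddle.zeros([...]) for "no cache") rather than Optional[Tensor]/None, because, as the "When will `if mask.size(2) > 0` be False?" comment notes, the jit/onnx-exported graph can only branch on a runtime dimension, not on a None check. A minimal sketch of the convention, assuming only the paddle APIs already visible in these hunks:

import paddle

no_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)     # sentinel: no mask given
real_mask = paddle.ones([1, 1, 16], dtype=paddle.bool)  # (batch, 1, time2)

for mask in (no_mask, real_mask):
    # exported code tests a dynamic dimension instead of `mask is None`
    if paddle.shape(mask)[2] > 0:  # time2 > 0
        print("apply mask of shape", mask.shape)
    else:
        print("sentinel received, skip masking")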
@@ -127,11 +128,11 @@ class ConvolutionModule(nn.Layer): x = x.transpose([0, 2, 1]) # [B, C, T] # mask batch padding - if paddle.shape(mask_pad)[2] > 0: # time > 0 + if paddle.shape(mask_pad)[2] > 0: # time > 0 x = x.masked_fill(mask_pad, 0.0) if self.lorder > 0: - if paddle.shape(cache)[2] == 0: # cache_t == 0 + if paddle.shape(cache)[2] == 0: # cache_t == 0 x = nn.functional.pad( x, [self.lorder, 0], 'constant', 0.0, data_format='NCL') else: @@ -161,7 +162,7 @@ class ConvolutionModule(nn.Layer): x = self.pointwise_conv2(x) # mask batch padding - if paddle.shape(mask_pad)[2] > 0: # time > 0 + if paddle.shape(mask_pad)[2] > 0: # time > 0 x = x.masked_fill(mask_pad, 0.0) x = x.transpose([0, 2, 1]) # [B, T, C] diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py index bff2d69bb..dc7b6666a 100644 --- a/paddlespeech/s2t/modules/encoder.py +++ b/paddlespeech/s2t/modules/encoder.py @@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), - cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]), - att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Forward just one chunk Args: @@ -227,7 +227,7 @@ class BaseEncoder(nn.Layer): xs = self.global_cmvn(xs) # before embed, xs=(B, T, D1), pos_emb=(B=1, T, D) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset) # after embed, xs=(B=1, chunk_size, hidden-dim) elayers = paddle.shape(att_cache)[0] @@ -252,14 +252,16 @@ class BaseEncoder(nn.Layer): # att_cache[i:i+1] = (1, head, cache_t1, d_k*2) # cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2) xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, - att_cache=att_cache[i:i+1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, - ) + xs, + att_mask, + pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i:i + 1] + if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, ) # new_att_cache = (1, head, attention_key_size, d_k*2) # new_cnn_cache = (B=1, hidden-dim, cache_t2) - r_att_cache.append(new_att_cache[:,:, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim if self.normalize_before: xs = self.after_norm(xs) @@ -270,7 +272,6 @@ class BaseEncoder(nn.Layer): r_cnn_cache = paddle.concat(r_cnn_cache, axis=0) return xs, r_att_cache, r_cnn_cache - def forward_chunk_by_chunk( self, xs: paddle.Tensor, @@ -315,8 +316,8 @@ class BaseEncoder(nn.Layer): num_frames = xs.shape[1] required_cache_size = decoding_chunk_size * num_decoding_left_chunks - att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]) - cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]) + att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]) + cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]) outputs = [] offset = 0 @@ -326,7 +327,7 @@ class BaseEncoder(nn.Layer): chunk_xs = xs[:, cur:end, :] (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + chunk_xs, offset, 
required_cache_size, att_cache, cnn_cache) outputs.append(y) offset += y.shape[1] diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py index cdcf2e052..e37837d2f 100644 --- a/paddlespeech/s2t/modules/initializer.py +++ b/paddlespeech/s2t/modules/initializer.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np + class DefaultInitializerContext(object): """ egs: diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py index 175e8ffb6..11f50655f 100644 --- a/paddlespeech/server/bin/paddlespeech_server.py +++ b/paddlespeech/server/bin/paddlespeech_server.py @@ -18,7 +18,6 @@ from typing import List import uvicorn from fastapi import FastAPI -from starlette.middleware.cors import CORSMiddleware from prettytable import PrettyTable from starlette.middleware.cors import CORSMiddleware @@ -46,6 +45,7 @@ app.add_middleware( allow_methods=["*"], allow_headers=["*"]) + @cli_server_register( name='paddlespeech_server.start', description='Start the service') class ServerExecutor(BaseExecutor): diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py index b87dbe805..1b8ad1cb7 100644 --- a/paddlespeech/server/engine/asr/online/ctc_endpoint.py +++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py @@ -102,8 +102,10 @@ class OnlineCTCEndpoint: assert self.num_frames_decoded >= self.trailing_silence_frames assert self.frame_shift_in_ms > 0 - - decoding_something = (self.num_frames_decoded > self.trailing_silence_frames) and decoding_something + + decoding_something = ( + self.num_frames_decoded > self.trailing_silence_frames + ) and decoding_something utterance_length = self.num_frames_decoded * self.frame_shift_in_ms trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index ab4f11305..6daae5be3 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -21,12 +21,12 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils import onnx_infer diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index 182e64180..0fd5d1bc6 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -21,10 +21,10 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource -from paddlespeech.audio.transform.transformation import 
Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 4df38f09d..269a33ba7 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -21,10 +21,10 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource -from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.tensor_utils import add_sos_eos @@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler: ## conformer # cache for conformer online - self.att_cache = paddle.zeros([0,0,0,0]) - self.cnn_cache = paddle.zeros([0,0,0,0]) + self.att_cache = paddle.zeros([0, 0, 0, 0]) + self.cnn_cache = paddle.zeros([0, 0, 0, 0]) self.encoder_out = None # conformer decoding state @@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler: # cur chunk chunk_xs = self.cached_feat[:, cur:end, :] # forward chunk - (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk( - chunk_xs, self.offset, required_cache_size, - self.att_cache, self.cnn_cache) + (y, self.att_cache, + self.cnn_cache) = self.model.encoder.forward_chunk( + chunk_xs, self.offset, required_cache_size, self.att_cache, + self.cnn_cache) outputs.append(y) # update the global offset, in decoding frame unit diff --git a/paddlespeech/t2s/exps/ernie_sat/align.py b/paddlespeech/t2s/exps/ernie_sat/align.py index 529a8221c..513c57e70 100755 --- a/paddlespeech/t2s/exps/ernie_sat/align.py +++ b/paddlespeech/t2s/exps/ernie_sat/align.py @@ -19,9 +19,9 @@ import librosa import numpy as np import pypinyin from praatio import textgrid -from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name -from paddlespeech.t2s.exps.ernie_sat.utils import get_dict +from paddlespeech.t2s.exps.ernie_sat.utils import get_dict +from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name DICT_EN = 'tools/aligner/cmudict-0.7b' DICT_ZH = 'tools/aligner/simple.lexicon' @@ -30,6 +30,7 @@ MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip' MFA_PATH = 'tools/montreal-forced-aligner/bin' os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH'] + def _get_max_idx(dic): return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] @@ -106,11 +107,11 @@ def alignment(wav_path: str, wav_name = os.path.basename(wav_path) utt = wav_name.split('.')[0] # prepare data for MFA - tmp_name = get_tmp_name(text=text) + tmp_name = get_tmp_name(text=text) tmpbase = './tmp_dir/' + tmp_name tmpbase = Path(tmpbase) tmpbase.mkdir(parents=True, exist_ok=True) - print("tmp_name in alignment:",tmp_name) + print("tmp_name in alignment:", tmp_name) shutil.copyfile(wav_path, tmpbase / wav_name) txt_name = utt + '.txt' diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py index 95b07367c..fac4cef87 100644 --- a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py @@ 
-15,31 +15,24 @@ import librosa import numpy as np import soundfile as sf +from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn +from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy -from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn +from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name from paddlespeech.t2s.exps.syn_utils import get_frontend -from paddlespeech.t2s.datasets.get_feats import LogMelFBank from paddlespeech.t2s.exps.syn_utils import norm -from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name - - - - def _p2id(self, phonemes: List[str]) -> np.ndarray: # replace unk phone with sp - phonemes = [ - phn if phn in vocab_phones else "sp" for phn in phonemes - ] + phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes] phone_ids = [vocab_phones[item] for item in phonemes] return np.array(phone_ids, np.int64) - def prep_feats_with_dur(wav_path: str, old_str: str='', new_str: str='', @@ -152,8 +145,6 @@ def prep_feats_with_dur(wav_path: str, return outs - - def prep_feats(wav_path: str, old_str: str='', new_str: str='', @@ -188,32 +179,32 @@ def prep_feats(wav_path: str, mel = mel_extractor.get_log_mel_fbank(wav) erniesat_mean, erniesat_std = np.load(erniesat_stat) normed_mel = norm(mel, erniesat_mean, erniesat_std) - tmp_name = get_tmp_name(text=old_str) + tmp_name = get_tmp_name(text=old_str) tmpbase = './tmp_dir/' + tmp_name tmpbase = Path(tmpbase) tmpbase.mkdir(parents=True, exist_ok=True) - print("tmp_name in synthesize_e2e:",tmp_name) + print("tmp_name in synthesize_e2e:", tmp_name) mel_path = tmpbase / 'mel.npy' - print("mel_path:",mel_path) + print("mel_path:", mel_path) np.save(mel_path, logmel) durations = [e - s for e, s in zip(mfa_end, mfa_start)] - datum={ - "utt_id": utt_id, - "spk_id": 0, - "text": text, - "text_lengths": len(text), - "speech_lengths": 115, - "durations": durations, - "speech": mel_path, - "align_start": mfa_start, + datum = { + "utt_id": utt_id, + "spk_id": 0, + "text": text, + "text_lengths": len(text), + "speech_lengths": 115, + "durations": durations, + "speech": mel_path, + "align_start": mfa_start, "align_end": mfa_end, "span_bdy": span_bdy } batch = collate_fn([datum]) - print("batch:",batch) + print("batch:", batch) return batch, old_span_bdy, new_span_bdy @@ -240,8 +231,6 @@ def decode_with_model(mlm_model: nn.Layer, fs=fs, n_shift=n_shift, token_list=token_list) - - feats = collate_fn(batch)[1] @@ -275,7 +264,6 @@ if __name__ == '__main__': # for synthesize append_str = "do you love me i love you so much" new_str = old_str + append_str - ''' outs = prep_feats_with_dur( wav_path=wav_path, @@ -306,7 +294,7 @@ if __name__ == '__main__': with open(erniesat_config) as f: erniesat_config = CfgNode(yaml.safe_load(f)) - + erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy" # Extractor @@ -319,16 +307,14 @@ if __name__ == '__main__': n_mels=erniesat_config.n_mels, fmin=erniesat_config.fmin, fmax=erniesat_config.fmax) - - collate_fn = build_erniesat_collate_fn( mlm_prob=erniesat_config.mlm_prob, mean_phn_span=erniesat_config.mean_phn_span, seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm', text_masking=False) - - 
phones_dict='/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt' + + phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt' vocab_phones = {} with open(phones_dict, 'rt') as f: @@ -336,11 +322,9 @@ if __name__ == '__main__': for phn, id in phn_id: vocab_phones[phn] = int(id) - prep_feats(wav_path=wav_path, - old_str=old_str, - new_str=new_str, - fs=fs, - n_shift=n_shift) - - - + prep_feats( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + fs=fs, + n_shift=n_shift) diff --git a/paddlespeech/t2s/exps/ernie_sat/utils.py b/paddlespeech/t2s/exps/ernie_sat/utils.py index 9169efa36..6805e513c 100644 --- a/paddlespeech/t2s/exps/ernie_sat/utils.py +++ b/paddlespeech/t2s/exps/ernie_sat/utils.py @@ -11,32 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import os from pathlib import Path from typing import Dict from typing import List from typing import Union -import os import numpy as np import paddle import yaml from yacs.config import CfgNode -import hashlib - from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference + def _get_user(): return os.path.expanduser('~').split('/')[-1] + def str2md5(string): md5_val = hashlib.md5(string.encode('utf8')).hexdigest() return md5_val -def get_tmp_name(text:str): + +def get_tmp_name(text: str): return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text) + def get_dict(dictfile: str): word2phns_dict = {} with open(dictfile, 'r') as fid: diff --git a/paddlespeech/t2s/exps/tacotron2/normalize.py b/paddlespeech/t2s/exps/tacotron2/normalize.py index 64848f899..6c250a3b9 120000 --- a/paddlespeech/t2s/exps/tacotron2/normalize.py +++ b/paddlespeech/t2s/exps/tacotron2/normalize.py @@ -1 +1 @@ -../transformer_tts/normalize.py \ No newline at end of file +#../transformer_tts/normalize.py diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py index 87f6c7cff..dfd9e6352 100644 --- a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py +++ b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import argparse +import os from concurrent.futures import ThreadPoolExecutor from operator import itemgetter from pathlib import Path @@ -24,165 +25,15 @@ import librosa import numpy as np import tqdm import yaml -from yacs.config import CfgNode as Configuration -import re -from paddlespeech.t2s.datasets.get_feats import LogMelFBank -from paddlespeech.t2s.frontend import English,Chinese - - -def get_lj_sentences(file_name, frontend): - '''read MFA duration.txt - - Args: - file_name (str or Path) - Returns: - Dict: sentence: {'utt': ([char], [int])} - ''' - f = open(file_name, 'r') - sentence = {} - speaker_set = set() - for line in f: - line_list = line.strip().split('|') - utt = line_list[0] - speaker = utt.split("-")[0][:2] - speaker_set.add(speaker) - raw_text = line_list[-1] - phonemes = frontend.phoneticize(raw_text) - phonemes = phonemes[1:-1] - phonemes = [phn for phn in phonemes if not phn.isspace()] - sentence[utt] = (phonemes, speaker) - f.close() - return sentence, speaker_set - -def get_csmsc_sentences(file_name,fronten): - '''read MFA duration.txt - - Args: - file_name (str or Path) - Returns: - Dict: sentence: {'utt': ([char], [int])} - ''' - sentence = {} - speaker_set = set() - utt = 'girl' - with open(file_name, mode='r', encoding='utf-8') as f: - lines = f.readlines() - for i in range(len(lines)): - ann = lines[i] - if i % 2 == 0: - head = ann.strip('\n|\t').split('\t') - body = re.sub(r'[0-9]|#', '', head[-1]) - phonemes = fronten.phoneticize(body) - phonemes = phonemes[1:-1] - phonemes = [phn for phn in phonemes if not phn.isspace()] - sentence[head[0]] = (phonemes, utt) - speaker_set.add(utt) - f.close() - return sentence, speaker_set - -def get_input_token(sentence, output_path): - '''get phone set from training data and save it - - Args: - sentence (Dict): sentence: {'utt': ([char], str)} - output_path (str or path): path to save phone_id_map - ''' - phn_token = set() - for utt in sentence: - for phn in sentence[utt][0]: - if phn != "": - phn_token.add(phn) - phn_token = list(phn_token) - phn_token.sort() - phn_token = ["", ""] + phn_token - phn_token += [""] - - with open(output_path, 'w') as f: - for i, phn in enumerate(phn_token): - f.write(phn + ' ' + str(i) + '\n') - - -def get_spk_id_map(speaker_set, output_path): - speakers = sorted(list(speaker_set)) - with open(output_path, 'w') as f: - for i, spk in enumerate(speakers): - f.write(spk + ' ' + str(i) + '\n') - - -def process_sentence(config: Dict[str, Any], - fp: Path, - sentences: Dict, - output_dir: Path, - mel_extractor=None): - utt_id = fp.stem - record = None - if utt_id in sentences: - # reading, resampling may occur - wav, _ = librosa.load(str(fp), sr=config.fs) - if len(wav.shape) != 1 or np.abs(wav).max() > 1.0: - return record - assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio." - assert np.abs(wav).max( - ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM." 
- phones = sentences[utt_id][0] - speaker = sentences[utt_id][1] - logmel = mel_extractor.get_log_mel_fbank(wav, base='e') - # change duration according to mel_length - num_frames = logmel.shape[0] - mel_dir = output_dir / "data_speech" - mel_dir.mkdir(parents=True, exist_ok=True) - mel_path = mel_dir / (utt_id + "_speech.npy") - np.save(mel_path, logmel) - record = { - "utt_id": utt_id, - "phones": phones, - "text_lengths": len(phones), - "speech_lengths": num_frames, - "speech": str(mel_path), - "speaker": speaker - } - return record - - def process_sentences(config, - fps: List[Path], - sentences: Dict, - output_dir: Path, - mel_extractor=None, - nprocs: int=1): - - if nprocs == 1: - results = [] - for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence( - config=config, - fp=fp, - sentences=sentences, - output_dir=output_dir, - mel_extractor=mel_extractor) - if record: - results.append(record) - else: - with ThreadPoolExecutor(nprocs) as pool: - futures = [] - with tqdm.tqdm(total=len(fps)) as progress: - for fp in fps: - future = pool.submit(process_sentence, config, fp, - sentences, output_dir, mel_extractor) - future.add_done_callback(lambda p: progress.update()) - futures.append(future) - - results = [] - for ft in futures: - record = ft.result() - if record: - results.append(record) +from yacs - results.sort(key=itemgetter("utt_id")) - with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer: - for item in results: - writer.write(item) - print("Done") +from paddlespeech.t2s.datasets.get_feats import LogMelFBank +from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length +from paddlespeech.t2s.datasets.preprocess_utils import get_input_token +from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur +from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map +from paddlespeech.t2s.datasets.preprocess_utils import merge_silence +from paddlespeech.t2s.utils import str2bool def main(): @@ -192,71 +43,113 @@ def main(): parser.add_argument( "--dataset", - default="csmsc", + default="baker", type=str, - help="name of dataset, should in {ljspeech,csmsc} now") + help="name of dataset, should be in {baker, aishell3, ljspeech, vctk} now") parser.add_argument( - "--rootdir", default='./BZNSYP/', type=str, help="directory to dataset.") + "--rootdir", default=None, type=str, help="directory to dataset.") parser.add_argument( "--dumpdir", type=str, - default='./dump/', - #required=True, + required=True, help="directory to dump feature files.") - parser.add_argument( - "--config-path", - default="./default.yaml", - type=str, - help="yaml format configuration file.") + "--dur-file", default=None, type=str, help="path to durations.txt.") + + parser.add_argument("--config", type=str, help="transformer config file.") parser.add_argument( "--num-cpu", type=int, default=1, help="number of process.") + parser.add_argument( + "--cut-sil", + type=str2bool, + default=True, + help="whether to cut silence at the edges of the audio") + + parser.add_argument( + "--spk_emb_dir", + default=None, + type=str, + help="directory to speaker embedding files.") args = parser.parse_args()
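Note: the new --cut-sil flag is parsed through paddlespeech.t2s.utils.str2bool, which this patch imports but whose body is not shown here. A plausible argparse-friendly implementation, given only as an assumption for illustration:

import argparse

def str2bool(value: str) -> bool:
    # hypothetical stand-in for paddlespeech.t2s.utils.str2bool
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--cut-sil", type=str2bool, default=True)
print(parser.parse_args(["--cut-sil", "false"]).cut_sil)  # -> False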
- config_path = Path(args.config_path).resolve() - root_dir = Path(args.rootdir).expanduser() + + rootdir = Path(args.rootdir).expanduser() dumpdir = Path(args.dumpdir).expanduser() # use absolute path dumpdir = dumpdir.resolve() dumpdir.mkdir(parents=True, exist_ok=True) + dur_file = Path(args.dur_file).expanduser() - assert root_dir.is_dir() + if args.spk_emb_dir: + spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve() + else: + spk_emb_dir = None + + assert rootdir.is_dir() + assert dur_file.is_file() - with open(config_path, 'rt') as f: - _C = yaml.safe_load(f) - _C = Configuration(_C) - config = _C.clone() + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + sentences, speaker_set = get_phn_dur(dur_file) + + merge_silence(sentences) phone_id_map_path = dumpdir / "phone_id_map.txt" speaker_id_map_path = dumpdir / "speaker_id_map.txt" - if args.dataset == "csmsc": - wav_files = sorted(list((root_dir / "Wave").rglob("*.wav"))) - frontend = Chinese() - sentences, speaker_set = get_csmsc_sentences(root_dir / "000001-010000.txt", frontend) - print(speaker_set) - get_input_token(sentences, phone_id_map_path) - get_spk_id_map(speaker_set, speaker_id_map_path) - num_train = 9000 + get_input_token(sentences, phone_id_map_path, args.dataset) + get_spk_id_map(speaker_set, speaker_id_map_path) + + if args.dataset == "baker": + wav_files = sorted(list((rootdir / "Wave").rglob("*.wav"))) + # split data into 3 sections + num_train = 9800 num_dev = 100 train_wav_files = wav_files[:num_train] dev_wav_files = wav_files[num_train:num_train + num_dev] test_wav_files = wav_files[num_train + num_dev:] - - if args.dataset == "ljspeech": - wav_files = sorted(list((root_dir / "wavs").rglob("*.wav"))) - frontend = English() - sentences, speaker_set = get_lj_sentences(root_dir / "metadata.csv",frontend) - get_input_token(sentences, phone_id_map_path) - get_spk_id_map(speaker_set, speaker_id_map_path) + elif args.dataset == "aishell3": + sub_num_dev = 5 + wav_dir = rootdir / "train" / "wav" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*.wav"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + elif args.dataset == "ljspeech": + wav_files = sorted(list((rootdir / "wavs").rglob("*.wav"))) # split data into 3 sections num_train = 12900 num_dev = 100 train_wav_files = wav_files[:num_train] dev_wav_files = wav_files[num_train:num_train + num_dev] test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "vctk": + sub_num_dev = 5 + wav_dir = rootdir / "wav48_silence_trimmed" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + else: + print("dataset should be in {baker, aishell3, ljspeech, vctk} now!") train_dump_dir = dumpdir / "train" / "raw" train_dump_dir.mkdir(parents=True, exist_ok=True) @@ -284,7 +177,9 @@ sentences=sentences, output_dir=train_dump_dir, mel_extractor=mel_extractor, - nprocs=args.num_cpu) + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( config=config, @@ -292,7 +187,9 @@ sentences=sentences, output_dir=dev_dump_dir, mel_extractor=mel_extractor, - nprocs=args.num_cpu) + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( config=config, @@ -300,7 +197,9 @@ def
main(): sentences=sentences, output_dir=test_dump_dir, mel_extractor=mel_extractor, - nprocs=args.num_cpu) + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if __name__ == "__main__": diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize.py b/paddlespeech/t2s/exps/transformer_tts/synthesize.py index 7b6b1873f..b66bbdc62 100644 --- a/paddlespeech/t2s/exps/transformer_tts/synthesize.py +++ b/paddlespeech/t2s/exps/transformer_tts/synthesize.py @@ -50,11 +50,14 @@ def evaluate(args, acoustic_model_config, vocoder_config): model.set_state_dict( paddle.load(args.transformer_tts_checkpoint)["main_params"]) model.eval() + # remove ".pdparams" in waveflow_checkpoint - vocoder_checkpoint_path = args.waveflow_checkpoint[:-9] if args.waveflow_checkpoint.endswith( - ".pdparams") else args.waveflow_checkpoint - vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config, - vocoder_checkpoint_path) + vocoder = get_voc_inference( + voc=args.voc, + voc_config=vocoder_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) + layer_tools.recursively_remove_weight_norm(vocoder) vocoder.eval() print("model done!") @@ -78,9 +81,8 @@ def evaluate(args, acoustic_model_config, vocoder_config): with paddle.no_grad(): mel = transformer_tts_inference(text) # mel shape is (T, feats) and waveflow's input shape is (batch, feats, T) - mel = mel.unsqueeze(0).transpose([0, 2, 1]) # wavflow's output shape is (B, T) - wav = vocoder.infer(mel)[0] + wav = vocoder(mel) sf.write( str(output_dir / (utt_id + ".wav")), @@ -106,18 +108,34 @@ def main(): type=str, help="mean and standard deviation used to normalize spectrogram when training transformer tts." ) + + # vocoder + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk', + 'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc', + 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk', + 'style_melgan_csmsc' + ], + help='Choose vocoder type of tts task.') parser.add_argument( - "--waveflow-config", type=str, help="waveflow config file.") - # not normalize when training waveflow + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( - "--waveflow-checkpoint", type=str, help="waveflow checkpoint to load.") + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( - "--phones-dict", type=str, default=None, help="phone vocabulary file.") - - parser.add_argument("--test-metadata", type=str, help="test metadata.") - parser.add_argument("--output-dir", type=str, help="output dir.") + "--voc_stat", + type=str, + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." 
+ ) + # other parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir.") args = parser.parse_args() @@ -130,16 +148,16 @@ def main(): with open(args.transformer_tts_config) as f: transformer_tts_config = CfgNode(yaml.safe_load(f)) - with open(args.waveflow_config) as f: - waveflow_config = CfgNode(yaml.safe_load(f)) + with open(args.voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) print("========Args========") print(yaml.safe_dump(vars(args))) print("========Config========") print(transformer_tts_config) - print(waveflow_config) + print(voc_config) - evaluate(args, transformer_tts_config, waveflow_config) + evaluate(args, transformer_tts_config, voc_config) if __name__ == "__main__": diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py index 0cd7d224e..716deb7bc 100644 --- a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py @@ -21,6 +21,7 @@ import soundfile as sf import yaml from yacs.config import CfgNode +from paddlespeech.t2s.frontend import Chinese from paddlespeech.t2s.frontend import English from paddlespeech.t2s.models.transformer_tts import TransformerTTS from paddlespeech.t2s.models.transformer_tts import TransformerTTSInference @@ -59,15 +60,15 @@ def evaluate(args, acoustic_model_config, vocoder_config): model.eval() # remove ".pdparams" in waveflow_checkpoint - vocoder_checkpoint_path = args.waveflow_checkpoint[:-9] if args.waveflow_checkpoint.endswith( - ".pdparams") else args.waveflow_checkpoint - vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config, - vocoder_checkpoint_path) - layer_tools.recursively_remove_weight_norm(vocoder) + vocoder = get_voc_inference( + voc=args.voc, + voc_config=vocoder_config, + voc_ckpt=args.voc_ckpt, + voc_stat=args.voc_stat) vocoder.eval() print("model done!") - frontend = English() + frontend = Chinese() print("frontend done!") stat = np.load(args.transformer_tts_stat) @@ -90,11 +91,10 @@ def evaluate(args, acoustic_model_config, vocoder_config): phones = [phn if phn in phone_id_map else "," for phn in phones] phone_ids = [phone_id_map[phn] for phn in phones] with paddle.no_grad(): - mel = transformer_tts_inference(paddle.to_tensor(phone_ids)) - # mel shape is (T, feats) and waveflow's input shape is (batch, feats, T) - mel = mel.unsqueeze(0).transpose([0, 2, 1]) - # wavflow's output shape is (B, T) - wav = vocoder.infer(mel)[0] + tensor_phone_ids = paddle.to_tensor(phone_ids) + mel = transformer_tts_inference(tensor_phone_ids) + + wav = vocoder(mel) sf.write( str(output_dir / (utt_id + ".wav")), @@ -120,23 +120,51 @@ def main(): type=str, help="mean and standard deviation used to normalize spectrogram when training transformer tts." 
) + + # vocoder + parser.add_argument( + '--voc', + type=str, + default='pwgan_csmsc', + choices=[ + 'pwgan_csmsc', + 'pwgan_ljspeech', + 'pwgan_aishell3', + 'pwgan_vctk', + 'mb_melgan_csmsc', + 'style_melgan_csmsc', + 'hifigan_csmsc', + 'hifigan_ljspeech', + 'hifigan_aishell3', + 'hifigan_vctk', + 'wavernn_csmsc', + ], + help='Choose vocoder type of tts task.') parser.add_argument( - "--waveflow-config", type=str, help="waveflow config file.") - # not normalize when training waveflow + '--voc_config', type=str, default=None, help='Config of voc.') parser.add_argument( - "--waveflow-checkpoint", type=str, help="waveflow checkpoint to load.") + '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( - "--phones-dict", + "--voc_stat", type=str, - default="phone_id_map.txt", - help="phone vocabulary file.") + default=None, + help="mean and standard deviation used to normalize spectrogram when training voc." + ) + # other parser.add_argument( - "--text", + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + '--lang', type=str, - help="text to synthesize, a 'utt_id sentence' pair per line.") - parser.add_argument("--output-dir", type=str, help="output dir.") + default='zh', + help='Choose model language. zh or en or mix') parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") args = parser.parse_args() @@ -149,16 +177,16 @@ def main(): with open(args.transformer_tts_config) as f: transformer_tts_config = CfgNode(yaml.safe_load(f)) - with open(args.waveflow_config) as f: - waveflow_config = CfgNode(yaml.safe_load(f)) + with open(args.voc_config) as f: + voc_config = CfgNode(yaml.safe_load(f)) print("========Args========") print(yaml.safe_dump(vars(args))) print("========Config========") print(transformer_tts_config) - print(waveflow_config) + print(voc_config) - evaluate(args, transformer_tts_config, waveflow_config) + evaluate(args, transformer_tts_config, voc_config) if __name__ == "__main__": diff --git a/paddlespeech/t2s/modules/transformer/repeat.py b/paddlespeech/t2s/modules/transformer/repeat.py index 43d11e9f9..d7f45e676 100644 --- a/paddlespeech/t2s/modules/transformer/repeat.py +++ b/paddlespeech/t2s/modules/transformer/repeat.py @@ -38,4 +38,4 @@ def repeat(N, fn): Returns: MultiSequential: Repeated model instance. 
""" - return MultiSequential(* [fn(n) for n in range(N)]) + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/setup.py b/setup.py index 1cc82fa76..b9e1f7c73 100644 --- a/setup.py +++ b/setup.py @@ -76,12 +76,7 @@ base = [ "pybind11", ] -server = [ - "fastapi", - "uvicorn", - "pattern_singleton", - "websockets" -] +server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"] requirements = { "install": diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py index 4426d1be8..386b4d9b8 100755 --- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py +++ b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py @@ -496,6 +496,10 @@ class SymbolicShapeInference: + + + + 'Attention', 'BiasGelu', \ 'EmbedLayerNormalization', \ 'FastGelu', 'Gelu', 'LayerNormalization', \ @@ -514,8 +518,8 @@ class SymbolicShapeInference: if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: initializers = [ self.initializers_[name] for name in node.input - if (name in self.initializers_ and - name not in self.graph_inputs_) + if (name in self.initializers_ and name not in + self.graph_inputs_) ] # run single node inference with self.known_vi_ shapes @@ -601,8 +605,8 @@ class SymbolicShapeInference: for o in symbolic_shape_inference.out_mp_.graph.output ] subgraph_new_symbolic_dims = set([ - d for s in subgraph_shapes if s for d in s - if type(d) == str and not d in self.symbolic_dims_ + d for s in subgraph_shapes + if s for d in s if type(d) == str and not d in self.symbolic_dims_ ]) new_dims = {} for d in subgraph_new_symbolic_dims: @@ -729,8 +733,9 @@ class SymbolicShapeInference: for d, s in zip(sympy_shape[-rank:], strides) ] total_pads = [ - max(0, (k - s) if r == 0 else (k - r)) for k, s, r in - zip(effective_kernel_shape, strides, residual) + max(0, (k - s) if r == 0 else (k - r)) + for k, s, r in zip(effective_kernel_shape, strides, + residual) ] except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational total_pads = [ @@ -1276,8 +1281,9 @@ class SymbolicShapeInference: if pads is not None: assert len(pads) == 2 * rank new_sympy_shape = [ - d + pad_up + pad_down for d, pad_up, pad_down in - zip(sympy_shape, pads[:rank], pads[rank:]) + d + pad_up + pad_down + for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[ + rank:]) ] self._update_computed_dims(new_sympy_shape) else: @@ -1590,8 +1596,8 @@ class SymbolicShapeInference: scales = list(scales) new_sympy_shape = [ sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in - zip(input_sympy_shape, roi_start, roi_end, scales) + for d, start, end, scale in zip(input_sympy_shape, + roi_start, roi_end, scales) ] self._update_computed_dims(new_sympy_shape) else: @@ -2204,8 +2210,9 @@ class SymbolicShapeInference: # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate sorted_nodes = [] sorted_known_vi = set([ - i.name for i in list(self.out_mp_.graph.input) + - list(self.out_mp_.graph.initializer) + i.name + for i in list(self.out_mp_.graph.input) + list( + self.out_mp_.graph.initializer) ]) if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): # Loop/Scan will have some graph output in graph inputs, so don't do topological sort