From deab2c75ba6ff6583ef71019cc5ed485f091f421 Mon Sep 17 00:00:00 2001
From: 吕志轩 <8252801+lv-zhixuan@user.noreply.gitee.com>
Date: Sat, 24 Sep 2022 11:18:02 +0800
Subject: [PATCH] Train a 24k TransformerTTS model and standardize the
 preprocess and synthesize code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
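Note: the rewritten local/*.sh scripts in this patch follow PaddleSpeech's
stage/stop_stage convention, with one numbered block per vocoder. A minimal
sketch of the gating idiom (the echo bodies are hypothetical; the control
variables are the same ones the scripts below declare):

    stage=0
    stop_stage=0

    # a block runs only when its number lies inside [stage, stop_stage]
    if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
        echo "stage 0: synthesize with pwgan_csmsc"
    fi
    if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
        echo "stage 1: synthesize with mb_melgan_csmsc"
    fi

Edit stage/stop_stage at the top of a script to choose which blocks run;
for example, stage=3 stop_stage=3 runs only the hifigan block.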
.idea/.gitignore | 3 +
.idea/PaddleSpeech.iml | 12 +
.../inspectionProfiles/profiles_settings.xml | 6 +
.idea/misc.xml | 4 +
.idea/modules.xml | 8 +
.idea/vcs.xml | 6 +
demos/audio_searching/src/operations/load.py | 5 +-
demos/speech_web/API.md | 2 +-
demos/speech_web/speech_server/main.py | 163 +++++-----
.../speech_web/speech_server/requirements.txt | 11 +-
.../speech_server/src/AudioManeger.py | 88 +++---
.../speech_server/src/SpeechBase/asr.py | 17 +-
.../speech_server/src/SpeechBase/nlp.py | 18 +-
.../src/SpeechBase/sql_helper.py | 49 +--
.../speech_server/src/SpeechBase/tts.py | 93 +++---
.../speech_server/src/SpeechBase/vpr.py | 55 ++--
.../src/SpeechBase/vpr_encode.py | 9 +-
.../speech_server/src/WebsocketManeger.py | 3 +-
demos/speech_web/speech_server/src/robot.py | 45 +--
demos/speech_web/speech_server/src/util.py | 17 +-
.../local/rtf_from_log.py | 2 +-
docs/requirements.txt | 35 ++-
docs/source/conf.py | 5 +-
examples/csmsc/tts1/README.md | 16 +-
examples/csmsc/tts1/conf/default.yaml | 19 +-
examples/csmsc/tts1/local/preprocess.sh | 20 +-
examples/csmsc/tts1/local/synthesize.sh | 100 +++++-
examples/csmsc/tts1/local/synthesize_e2e.sh | 114 ++++++-
examples/ernie_sat/local/inference.py | 6 +-
examples/ernie_sat/local/inference_new.py | 8 +-
paddlespeech/__init__.py | 2 -
paddlespeech/audio/__init__.py | 6 +-
paddlespeech/audio/streamdata/__init__.py | 125 ++++----
paddlespeech/audio/streamdata/autodecode.py | 19 +-
paddlespeech/audio/streamdata/cache.py | 64 ++--
paddlespeech/audio/streamdata/compat.py | 65 ++--
.../audio/streamdata/extradatasets.py | 6 +-
paddlespeech/audio/streamdata/filters.py | 247 +++++++++------
paddlespeech/audio/streamdata/gopen.py | 62 ++--
paddlespeech/audio/streamdata/handlers.py | 5 +-
paddlespeech/audio/streamdata/mix.py | 17 +-
paddlespeech/audio/streamdata/paddle_utils.py | 3 +-
paddlespeech/audio/streamdata/pipeline.py | 15 +-
paddlespeech/audio/streamdata/shardlists.py | 75 +++--
paddlespeech/audio/streamdata/tariterators.py | 81 ++---
paddlespeech/audio/streamdata/utils.py | 31 +-
paddlespeech/audio/streamdata/writer.py | 77 ++---
paddlespeech/audio/text/text_featurizer.py | 2 +-
paddlespeech/audio/transform/perturb.py | 11 +-
paddlespeech/audio/transform/spec_augment.py | 1 +
paddlespeech/cli/executor.py | 2 +-
paddlespeech/s2t/__init__.py | 1 +
paddlespeech/s2t/exps/u2/model.py | 23 +-
paddlespeech/s2t/exps/u2_kaldi/model.py | 26 +-
paddlespeech/s2t/exps/u2_st/model.py | 19 +-
paddlespeech/s2t/io/dataloader.py | 145 +++++----
paddlespeech/s2t/io/sampler.py | 2 +-
paddlespeech/s2t/models/u2_st/u2_st.py | 12 +-
paddlespeech/s2t/modules/align.py | 39 ++-
paddlespeech/s2t/modules/attention.py | 23 +-
.../s2t/modules/conformer_convolution.py | 17 +-
paddlespeech/s2t/modules/encoder.py | 29 +-
paddlespeech/s2t/modules/initializer.py | 1 +
.../server/bin/paddlespeech_server.py | 2 +-
.../server/engine/asr/online/ctc_endpoint.py | 6 +-
.../engine/asr/online/onnx/asr_engine.py | 2 +-
.../asr/online/paddleinference/asr_engine.py | 2 +-
.../engine/asr/online/python/asr_engine.py | 13 +-
paddlespeech/t2s/exps/ernie_sat/align.py | 9 +-
.../t2s/exps/ernie_sat/synthesize_e2e.py | 68 ++---
paddlespeech/t2s/exps/ernie_sat/utils.py | 11 +-
paddlespeech/t2s/exps/tacotron2/normalize.py | 2 +-
.../exps/transformer_tts/preprocess_new.py | 284 ++++++------------
.../t2s/exps/transformer_tts/synthesize.py | 52 ++--
.../exps/transformer_tts/synthesize_e2e.py | 76 +++--
.../t2s/modules/transformer/repeat.py | 2 +-
setup.py | 7 +-
.../ds2_ol/onnx/local/onnx_infer_shape.py | 31 +-
78 files changed, 1551 insertions(+), 1208 deletions(-)
create mode 100644 .idea/.gitignore
create mode 100644 .idea/PaddleSpeech.iml
create mode 100644 .idea/inspectionProfiles/profiles_settings.xml
create mode 100644 .idea/misc.xml
create mode 100644 .idea/modules.xml
create mode 100644 .idea/vcs.xml
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 000000000..26d33521a
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/PaddleSpeech.iml b/.idea/PaddleSpeech.iml
new file mode 100644
index 000000000..8a05c6ed5
--- /dev/null
+++ b/.idea/PaddleSpeech.iml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.7" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forSourceFolders="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 000000000..105ce2da2
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 000000000..a2e120dcc
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 000000000..d4f59faea
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/PaddleSpeech.iml" filepath="$PROJECT_DIR$/.idea/PaddleSpeech.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 000000000..94a25f7f4
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py
index 0d9edb784..d1ea00576 100644
--- a/demos/audio_searching/src/operations/load.py
+++ b/demos/audio_searching/src/operations/load.py
@@ -26,8 +26,9 @@ def get_audios(path):
"""
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [
- item for sublist in [[os.path.join(dir, file) for file in files]
- for dir, _, files in list(os.walk(path))]
+ item
+ for sublist in [[os.path.join(dir, file) for file in files]
+ for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats
]
diff --git a/demos/speech_web/API.md b/demos/speech_web/API.md
index c51446749..f66ec138e 100644
--- a/demos/speech_web/API.md
+++ b/demos/speech_web/API.md
@@ -401,4 +401,4 @@ curl -X 'GET' \
"code": 0,
"result":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
"message": "ok"
-```
\ No newline at end of file
+```
diff --git a/demos/speech_web/speech_server/main.py b/demos/speech_web/speech_server/main.py
index b10176670..8d1de2abe 100644
--- a/demos/speech_web/speech_server/main.py
+++ b/demos/speech_web/speech_server/main.py
@@ -3,48 +3,53 @@
# 2. 接收录音音频,返回识别结果
# 3. 接收ASR识别结果,返回NLP对话结果
# 4. 接收NLP对话结果,返回TTS音频
-
+import argparse
import base64
-import yaml
-import os
-import json
import datetime
+import json
+import os
+from typing import List
+from typing import Optional
+
+import aiofiles
import librosa
-import soundfile as sf
import numpy as np
-import argparse
+import soundfile as sf
import uvicorn
-import aiofiles
-from typing import Optional, List
-from pydantic import BaseModel
-from fastapi import FastAPI, Header, File, UploadFile, Form, Cookie, WebSocket, WebSocketDisconnect
+import yaml
+from fastapi import Cookie
+from fastapi import FastAPI
+from fastapi import File
+from fastapi import Form
+from fastapi import Header
+from fastapi import UploadFile
+from fastapi import WebSocket
+from fastapi import WebSocketDisconnect
from fastapi.responses import StreamingResponse
-from starlette.responses import FileResponse
-from starlette.middleware.cors import CORSMiddleware
-from starlette.requests import Request
-from starlette.websockets import WebSocketState as WebSocketState
-
+from pydantic import BaseModel
from src.AudioManeger import AudioMannger
-from src.util import *
from src.robot import Robot
-from src.WebsocketManeger import ConnectionManager
from src.SpeechBase.vpr import VPR
+from src.util import *
+from src.WebsocketManeger import ConnectionManager
+from starlette.middleware.cors import CORSMiddleware
+from starlette.requests import Request
+from starlette.responses import FileResponse
+from starlette.websockets import WebSocketState as WebSocketState
from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.utils.audio_process import float2pcm
-
# 解析配置
-parser = argparse.ArgumentParser(
- prog='PaddleSpeechDemo', add_help=True)
+parser = argparse.ArgumentParser(prog='PaddleSpeechDemo', add_help=True)
parser.add_argument(
- "--port",
- action="store",
- type=int,
- help="port of the app",
- default=8010,
- required=False)
+ "--port",
+ action="store",
+ type=int,
+ help="port of the app",
+ default=8010,
+ required=False)
args = parser.parse_args()
port = args.port
@@ -60,39 +65,41 @@ ie_model_path = "source/model"
UPLOAD_PATH = "source/vpr"
WAV_PATH = "source/wav"
-
-base_sources = [
- UPLOAD_PATH, WAV_PATH
-]
+base_sources = [UPLOAD_PATH, WAV_PATH]
for path in base_sources:
os.makedirs(path, exist_ok=True)
-
# 初始化
app = FastAPI()
-chatbot = Robot(asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path)
+chatbot = Robot(
+ asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path)
manager = ConnectionManager()
aumanager = AudioMannger(chatbot)
aumanager.init()
-vpr = VPR(db_path, dim = 192, top_k = 5)
+vpr = VPR(db_path, dim=192, top_k=5)
+
# 服务配置
class NlpBase(BaseModel):
chat: str
+
class TtsBase(BaseModel):
- text: str
+ text: str
+
class Audios:
def __init__(self) -> None:
self.audios = b""
+
audios = Audios()
######################################################################
########################### ASR 服务 #################################
#####################################################################
+
# 接收文件,返回ASR结果
# 上传文件
@app.post("/asr/offline")
@@ -101,7 +108,8 @@ async def speech2textOffline(files: List[UploadFile]):
asr_res = ""
for file in files[:1]:
# 生成时间戳
- now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
+ now_name = "asr_offline_" + datetime.datetime.strftime(
+ datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
out_file_path = os.path.join(WAV_PATH, now_name)
async with aiofiles.open(out_file_path, 'wb') as out_file:
content = await file.read() # async read
@@ -111,9 +119,10 @@ async def speech2textOffline(files: List[UploadFile]):
asr_res = chatbot.speech2text(out_file_path)
return SuccessRequest(result=asr_res)
# else:
- # return ErrorRequest(message="文件不是.wav格式")
+ # return ErrorRequest(message="文件不是.wav格式")
return ErrorRequest(message="上传文件为空")
+
# 接收文件,同时将wav强制转成16k, int16类型
@app.post("/asr/offlinefile")
async def speech2textOfflineFile(files: List[UploadFile]):
@@ -121,7 +130,8 @@ async def speech2textOfflineFile(files: List[UploadFile]):
asr_res = ""
for file in files[:1]:
# 生成时间戳
- now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
+ now_name = "asr_offline_" + datetime.datetime.strftime(
+ datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
out_file_path = os.path.join(WAV_PATH, now_name)
async with aiofiles.open(out_file_path, 'wb') as out_file:
content = await file.read() # async read
@@ -132,22 +142,18 @@ async def speech2textOfflineFile(files: List[UploadFile]):
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
wav_base64 = base64.b64encode(wav_bytes).decode('utf8')
-
+
# 将文件重新写入
now_name = now_name[:-4] + "_16k" + ".wav"
out_file_path = os.path.join(WAV_PATH, now_name)
- sf.write(out_file_path,wav,16000)
+ sf.write(out_file_path, wav, 16000)
# 返回ASR识别结果
asr_res = chatbot.speech2text(out_file_path)
- response_res = {
- "asr_result": asr_res,
- "wav_base64": wav_base64
- }
+ response_res = {"asr_result": asr_res, "wav_base64": wav_base64}
return SuccessRequest(result=response_res)
-
- return ErrorRequest(message="上传文件为空")
+ return ErrorRequest(message="上传文件为空")
# 流式接收测试
@@ -161,15 +167,17 @@ async def speech2textOnlineRecive(files: List[UploadFile]):
print(f"audios长度变化: {len(audios.audios)}")
return SuccessRequest(message="接收成功")
+
# 采集环境噪音大小
@app.post("/asr/collectEnv")
async def collectEnv(files: List[UploadFile]):
- for file in files[:1]:
+ for file in files[:1]:
content = await file.read() # async read
# 初始化, wav 前44字节是头部信息
aumanager.compute_env_volume(content[44:])
vad_ = aumanager.vad_threshold
- return SuccessRequest(result=vad_,message="采集环境噪音成功")
+ return SuccessRequest(result=vad_, message="采集环境噪音成功")
+
# 停止录音
@app.get("/asr/stopRecord")
@@ -179,6 +187,7 @@ async def stopRecord():
print("Online录音暂停")
return SuccessRequest(message="停止成功")
+
# 恢复录音
@app.get("/asr/resumeRecord")
async def resumeRecord():
@@ -210,7 +219,7 @@ async def websocket_endpoint(websocket: WebSocket):
# print(f"用户-{user}-离开")
-# Online识别的ASR
+ # Online识别的ASR
@app.websocket('/ws/asr/onlineStream')
async def websocket_endpoint(websocket: WebSocket):
"""PaddleSpeech Online ASR Server api
@@ -298,12 +307,14 @@ async def websocket_endpoint(websocket: WebSocket):
except WebSocketDisconnect:
pass
+
######################################################################
########################### NLP 服务 #################################
#####################################################################
+
@app.post("/nlp/chat")
-async def chatOffline(nlp_base:NlpBase):
+async def chatOffline(nlp_base: NlpBase):
chat = nlp_base.chat
if not chat:
return ErrorRequest(message="传入文本为空")
@@ -311,8 +322,9 @@ async def chatOffline(nlp_base:NlpBase):
res = chatbot.chat(chat)
return SuccessRequest(result=res)
+
@app.post("/nlp/ie")
-async def ieOffline(nlp_base:NlpBase):
+async def ieOffline(nlp_base: NlpBase):
nlp_text = nlp_base.chat
if not nlp_text:
return ErrorRequest(message="传入文本为空")
@@ -320,17 +332,20 @@ async def ieOffline(nlp_base:NlpBase):
res = chatbot.ie(nlp_text)
return SuccessRequest(result=res)
+
######################################################################
########################### TTS 服务 #################################
#####################################################################
+
@app.post("/tts/offline")
-async def text2speechOffline(tts_base:TtsBase):
+async def text2speechOffline(tts_base: TtsBase):
text = tts_base.text
if not text:
return ErrorRequest(message="文本为空")
else:
- now_name = "tts_"+ datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
+ now_name = "tts_" + datetime.datetime.strftime(
+ datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
out_file_path = os.path.join(WAV_PATH, now_name)
# 保存为文件,再转成base64传输
chatbot.text2speech(text, outpath=out_file_path)
@@ -339,12 +354,14 @@ async def text2speechOffline(tts_base:TtsBase):
base_str = base64.b64encode(data_bin)
return SuccessRequest(result=base_str)
+
# http流式TTS
@app.post("/tts/online")
async def stream_tts(request_body: TtsBase):
text = request_body.text
return StreamingResponse(chatbot.text2speechStreamBytes(text=text))
+
# ws流式TTS
@app.websocket("/ws/tts/online")
async def stream_ttsWS(websocket: WebSocket):
@@ -356,17 +373,11 @@ async def stream_ttsWS(websocket: WebSocket):
if text:
for sub_wav in chatbot.text2speechStream(text=text):
# print("发送sub wav: ", len(sub_wav))
- res = {
- "wav": sub_wav,
- "done": False
- }
+ res = {"wav": sub_wav, "done": False}
await websocket.send_json(res)
-
+
# 输送结束
- res = {
- "wav": sub_wav,
- "done": True
- }
+ res = {"wav": sub_wav, "done": True}
await websocket.send_json(res)
# manager.disconnect(websocket)
@@ -396,8 +407,9 @@ async def vpr_enroll(table_name: str=None,
return {'status': False, 'msg': "spk_id can not be None"}
# Save the upload data to server.
content = await audio.read()
- now_name = "vpr_enroll_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
- audio_path = os.path.join(UPLOAD_PATH, now_name)
+ now_name = "vpr_enroll_" + datetime.datetime.strftime(
+ datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
+ audio_path = os.path.join(UPLOAD_PATH, now_name)
with open(audio_path, "wb+") as f:
f.write(content)
@@ -413,12 +425,13 @@ async def vpr_recog(request: Request,
audio: UploadFile=File(...)):
# Voice print recognition online
# try:
- # Save the upload data to server.
+ # Save the upload data to server.
content = await audio.read()
- now_name = "vpr_query_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
- query_audio_path = os.path.join(UPLOAD_PATH, now_name)
+ now_name = "vpr_query_" + datetime.datetime.strftime(
+ datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav"
+ query_audio_path = os.path.join(UPLOAD_PATH, now_name)
with open(query_audio_path, "wb+") as f:
- f.write(content)
+ f.write(content)
spk_ids, paths, scores = vpr.do_search_vpr(query_audio_path)
res = dict(zip(spk_ids, zip(paths, scores)))
@@ -426,7 +439,9 @@ async def vpr_recog(request: Request,
res = sorted(res.items(), key=lambda item: item[1][1], reverse=True)
return res
# except Exception as e:
- # return {'status': False, 'msg': e}, 400
+
+
+# return {'status': False, 'msg': e}, 400
@app.post('/vpr/del')
@@ -460,17 +475,18 @@ async def vpr_database64(vprId: int):
return {'status': False, 'msg': "vpr_id can not be None"}
audio_path = vpr.do_get_wav(vprId)
# 返回base64
-
+
# 将文件转成16k, 16bit类型的wav文件
wav, sr = librosa.load(audio_path, sr=16000)
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
wav_base64 = base64.b64encode(wav_bytes).decode('utf8')
-
+
return SuccessRequest(result=wav_base64)
except Exception as e:
return {'status': False, 'msg': e}, 400
+
@app.get('/vpr/data')
async def vpr_data(vprId: int):
# Get the audio file from path by spk_id in MySQL
@@ -482,11 +498,6 @@ async def vpr_data(vprId: int):
except Exception as e:
return {'status': False, 'msg': e}, 400
+
if __name__ == '__main__':
uvicorn.run(app=app, host='0.0.0.0', port=port)
-
-
-
-
-
-
diff --git a/demos/speech_web/speech_server/requirements.txt b/demos/speech_web/speech_server/requirements.txt
index 7e7bd1680..607f0d4d0 100644
--- a/demos/speech_web/speech_server/requirements.txt
+++ b/demos/speech_web/speech_server/requirements.txt
@@ -1,14 +1,14 @@
aiofiles
+faiss-cpu
fastapi
librosa
numpy
+paddlenlp
+paddlepaddle
+paddlespeech
pydantic
-scikit_learn
+python-multipart
+scikit_learn
SoundFile
starlette
uvicorn
-paddlepaddle
-paddlespeech
-paddlenlp
-faiss-cpu
-python-multipart
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/AudioManeger.py b/demos/speech_web/speech_server/src/AudioManeger.py
index 0deb03699..8daa07a32 100644
--- a/demos/speech_web/speech_server/src/AudioManeger.py
+++ b/demos/speech_web/speech_server/src/AudioManeger.py
@@ -1,15 +1,22 @@
+import datetime
import imp
-from queue import Queue
-import numpy as np
import os
-import wave
import random
-import datetime
+import wave
+from queue import Queue
+
+import numpy as np
+
from .util import randName
class AudioMannger:
- def __init__(self, robot, frame_length=160, frame=10, data_width=2, vad_default = 300):
+ def __init__(self,
+ robot,
+ frame_length=160,
+ frame=10,
+ data_width=2,
+ vad_default=300):
# 二进制 pcm 流
self.audios = b''
self.asr_result = ""
@@ -20,8 +27,9 @@ class AudioMannger:
os.makedirs(self.file_dir, exist_ok=True)
self.vad_deafult = vad_default
self.vad_threshold = vad_default
- self.vad_threshold_path = os.path.join(self.file_dir, "vad_threshold.npy")
-
+ self.vad_threshold_path = os.path.join(self.file_dir,
+ "vad_threshold.npy")
+
# 10ms 一帧
self.frame_length = frame_length
# 10帧,检测一次 vad
@@ -30,67 +38,64 @@ class AudioMannger:
self.data_width = data_width
# window
self.window_length = frame_length * frame * data_width
-
+
# 是否开始录音
self.on_asr = False
- self.silence_cnt = 0
+ self.silence_cnt = 0
self.max_silence_cnt = 4
self.is_pause = False # 录音暂停与恢复
-
-
-
+
def init(self):
if os.path.exists(self.vad_threshold_path):
# 平均响度文件存在
self.vad_threshold = np.load(self.vad_threshold_path)
-
-
+
def clear_audio(self):
# 清空 pcm 累积片段与 asr 识别结果
self.audios = b''
-
+
def clear_asr(self):
self.asr_result = ""
-
-
+
def compute_chunk_volume(self, start_index, pcm_bins):
# 根据帧长计算能量平均值
- pcm_bin = pcm_bins[start_index: start_index + self.window_length]
+ pcm_bin = pcm_bins[start_index:start_index + self.window_length]
# 转成 numpy
pcm_np = np.frombuffer(pcm_bin, np.int16)
# 归一化 + 计算响度
x = pcm_np.astype(np.float32)
x = np.abs(x)
- return np.mean(x)
-
-
+ return np.mean(x)
+
def is_speech(self, start_index, pcm_bins):
# 检查是否没
if start_index > len(pcm_bins):
return False
# 检查从这个 start 开始是否为静音帧
- energy = self.compute_chunk_volume(start_index=start_index, pcm_bins=pcm_bins)
+ energy = self.compute_chunk_volume(
+ start_index=start_index, pcm_bins=pcm_bins)
# print(energy)
if energy > self.vad_threshold:
return True
else:
return False
-
+
def compute_env_volume(self, pcm_bins):
max_energy = 0
start = 0
while start < len(pcm_bins):
- energy = self.compute_chunk_volume(start_index=start, pcm_bins=pcm_bins)
+ energy = self.compute_chunk_volume(
+ start_index=start, pcm_bins=pcm_bins)
if energy > max_energy:
max_energy = energy
start += self.window_length
self.vad_threshold = max_energy + 100 if max_energy > self.vad_deafult else self.vad_deafult
-
+
# 保存成文件
np.save(self.vad_threshold_path, self.vad_threshold)
print(f"vad 阈值大小: {self.vad_threshold}")
print(f"环境采样保存: {os.path.realpath(self.vad_threshold_path)}")
-
+
def stream_asr(self, pcm_bin):
# 先把 pcm_bin 送进去做端点检测
start = 0
@@ -99,7 +104,7 @@ class AudioMannger:
self.on_asr = True
self.silence_cnt = 0
print("录音中")
- self.audios += pcm_bin[ start : start + self.window_length]
+ self.audios += pcm_bin[start:start + self.window_length]
else:
if self.on_asr:
self.silence_cnt += 1
@@ -110,41 +115,42 @@ class AudioMannger:
print("录音停止")
# audios 保存为 wav, 送入 ASR
if len(self.audios) > 2 * 16000:
- file_path = os.path.join(self.file_dir, "asr_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav")
+ file_path = os.path.join(
+ self.file_dir,
+ "asr_" + datetime.datetime.strftime(
+ datetime.datetime.now(),
+ '%Y%m%d%H%M%S') + randName() + ".wav")
self.save_audio(file_path=file_path)
self.asr_result = self.robot.speech2text(file_path)
self.clear_audio()
- return self.asr_result
+ return self.asr_result
else:
# 正常接收
print("录音中 静音")
- self.audios += pcm_bin[ start : start + self.window_length]
+ self.audios += pcm_bin[start:start + self.window_length]
start += self.window_length
return ""
-
+
def save_audio(self, file_path):
print("保存音频")
- wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav"
- wf.setnchannels(1) # 设置声道数为2
- wf.setsampwidth(2) # 设置采样深度为
- wf.setframerate(16000) # 设置采样率为16000
+ wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav"
+        wf.setnchannels(1)  # 设置声道数为1
+        wf.setsampwidth(2)  # 设置采样深度为2字节
+ wf.setframerate(16000) # 设置采样率为16000
# 将数据写入创建的音频文件
wf.writeframes(self.audios)
# 写完后将文件关闭
wf.close()
-
+
def end(self):
# audios 保存为 wav, 送入 ASR
file_path = os.path.join(self.file_dir, "asr.wav")
self.save_audio(file_path=file_path)
return self.robot.speech2text(file_path)
-
+
def stop(self):
self.is_pause = True
self.audios = b''
-
+
def resume(self):
self.is_pause = False
-
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/SpeechBase/asr.py b/demos/speech_web/speech_server/src/SpeechBase/asr.py
index 8d4c0cffc..3b9e05c8b 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/asr.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/asr.py
@@ -1,13 +1,15 @@
from re import sub
+
+import librosa
import numpy as np
import paddle
-import librosa
import soundfile
from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine
from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler
from paddlespeech.server.utils.config import get_config
+
def readWave(samples):
x_len = len(samples)
@@ -31,20 +33,23 @@ def readWave(samples):
class ASR:
- def __init__(self, config_path, ) -> None:
+ def __init__(
+ self,
+ config_path, ) -> None:
self.config = get_config(config_path)['asr_online']
self.engine = ASREngine()
self.engine.init(self.config)
self.connection_handler = PaddleASRConnectionHanddler(self.engine)
-
+
def offlineASR(self, samples, sample_rate=16000):
- x_chunk, x_chunk_lens = self.engine.preprocess(samples=samples, sample_rate=sample_rate)
+ x_chunk, x_chunk_lens = self.engine.preprocess(
+ samples=samples, sample_rate=sample_rate)
self.engine.run(x_chunk, x_chunk_lens)
result = self.engine.postprocess()
self.engine.reset()
return result
- def onlineASR(self, samples:bytes=None, is_finished=False):
+ def onlineASR(self, samples: bytes=None, is_finished=False):
if not is_finished:
# 流式开始
self.connection_handler.extract_feat(samples)
@@ -58,5 +63,3 @@ class ASR:
asr_results = self.connection_handler.get_result()
self.connection_handler.reset()
return asr_results
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/SpeechBase/nlp.py b/demos/speech_web/speech_server/src/SpeechBase/nlp.py
index 4ece63256..b642a51d6 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/nlp.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/nlp.py
@@ -1,23 +1,23 @@
from paddlenlp import Taskflow
+
class NLP:
def __init__(self, ie_model_path=None):
schema = ["时间", "出发地", "目的地", "费用"]
if ie_model_path:
- self.ie_model = Taskflow("information_extraction",
- schema=schema, task_path=ie_model_path)
+ self.ie_model = Taskflow(
+ "information_extraction",
+ schema=schema,
+ task_path=ie_model_path)
else:
- self.ie_model = Taskflow("information_extraction",
- schema=schema)
-
+ self.ie_model = Taskflow("information_extraction", schema=schema)
+
self.dialogue_model = Taskflow("dialogue")
-
+
def chat(self, text):
result = self.dialogue_model([text])
return result[0]
-
+
def ie(self, text):
result = self.ie_model(text)
return result
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
index 6937def58..628dbc4e8 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py
@@ -1,18 +1,20 @@
import base64
-import sqlite3
import os
+import sqlite3
+
import numpy as np
from pkg_resources import resource_stream
-def dict_factory(cursor, row):
- d = {}
- for idx, col in enumerate(cursor.description):
- d[col[0]] = row[idx]
- return d
+def dict_factory(cursor, row):
+ d = {}
+ for idx, col in enumerate(cursor.description):
+ d[col[0]] = row[idx]
+ return d
+
class DataBase(object):
- def __init__(self, db_path:str):
+ def __init__(self, db_path: str):
db_path = os.path.realpath(db_path)
if os.path.exists(db_path):
@@ -21,12 +23,12 @@ class DataBase(object):
db_path_dir = os.path.dirname(db_path)
os.makedirs(db_path_dir, exist_ok=True)
self.db_path = db_path
-
+
self.conn = sqlite3.connect(self.db_path)
self.conn.row_factory = dict_factory
self.cursor = self.conn.cursor()
self.init_database()
-
+
def init_database(self):
"""
初始化数据库, 若表不存在则创建
@@ -41,12 +43,12 @@ class DataBase(object):
"""
self.cursor.execute(sql)
self.conn.commit()
-
+
def execute_base(self, sql, data_dict):
self.cursor.execute(sql, data_dict)
self.conn.commit()
-
- def insert_one(self, username, vector_base64:str, wav_path):
+
+ def insert_one(self, username, vector_base64: str, wav_path):
if not os.path.exists(wav_path):
return None, "wav not exists"
else:
@@ -55,6 +57,7 @@ class DataBase(object):
vprtable (username, vector, wavpath)
values (?, ?, ?)
"""
+
try:
self.cursor.execute(sql, (username, vector_base64, wav_path))
self.conn.commit()
@@ -63,25 +66,27 @@ class DataBase(object):
except Exception as e:
print(e)
return None, e
-
+
def select_all(self):
sql = """
SELECT * from vprtable
"""
result = self.cursor.execute(sql).fetchall()
return result
-
+
def select_by_id(self, vpr_id):
sql = f"""
SELECT * from vprtable WHERE `id` = {vpr_id}
"""
+
result = self.cursor.execute(sql).fetchall()
return result
-
+
def select_by_username(self, username):
sql = f"""
SELECT * from vprtable WHERE `username` = '{username}'
"""
+
result = self.cursor.execute(sql).fetchall()
return result
@@ -89,28 +94,30 @@ class DataBase(object):
sql = f"""
DELETE from vprtable WHERE `username`='{username}'
"""
+
self.cursor.execute(sql)
self.conn.commit()
-
+
def drop_all(self):
sql = f"""
DELETE from vprtable
"""
+
self.cursor.execute(sql)
self.conn.commit()
-
+
def drop_table(self):
sql = f"""
DROP TABLE vprtable
"""
+
self.cursor.execute(sql)
self.conn.commit()
-
- def encode_vector(self, vector:np.ndarray):
+
+ def encode_vector(self, vector: np.ndarray):
return base64.b64encode(vector).decode('utf8')
-
+
def decode_vector(self, vector_base64, dtype=np.float32):
b = base64.b64decode(vector_base64)
vc = np.frombuffer(b, dtype=dtype)
return vc
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/SpeechBase/tts.py b/demos/speech_web/speech_server/src/SpeechBase/tts.py
index d5ba0c802..030f8eef0 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/tts.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/tts.py
@@ -5,18 +5,20 @@
# 2. 加载模型
# 3. 端到端推理
# 4. 流式推理
-
import base64
-import math
import logging
+import math
+
import numpy as np
-from paddlespeech.server.utils.onnx_infer import get_sess
-from paddlespeech.t2s.frontend.zh_frontend import Frontend
-from paddlespeech.server.utils.util import denorm, get_chunks
+
+from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
from paddlespeech.server.utils.audio_process import float2pcm
from paddlespeech.server.utils.config import get_config
+from paddlespeech.server.utils.onnx_infer import get_sess
+from paddlespeech.server.utils.util import denorm
+from paddlespeech.server.utils.util import get_chunks
+from paddlespeech.t2s.frontend.zh_frontend import Frontend
-from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine
class TTS:
def __init__(self, config_path):
@@ -26,12 +28,12 @@ class TTS:
self.engine.init(self.config)
self.executor = self.engine.executor
#self.engine.warm_up()
-
+
# 前端初始化
self.frontend = Frontend(
- phone_vocab_path=self.engine.executor.phones_dict,
- tone_vocab_path=None)
-
+ phone_vocab_path=self.engine.executor.phones_dict,
+ tone_vocab_path=None)
+
def depadding(self, data, chunk_num, chunk_id, block, pad, upsample):
"""
Streaming inference removes the result of pad inference
@@ -48,39 +50,37 @@ class TTS:
data = data[front_pad * upsample:(front_pad + block) * upsample]
return data
-
+
def offlineTTS(self, text):
get_tone_ids = False
merge_sentences = False
-
+
input_ids = self.frontend.get_input_ids(
- text,
- merge_sentences=merge_sentences,
- get_tone_ids=get_tone_ids)
+ text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
wav_list = []
for i in range(len(phone_ids)):
orig_hs = self.engine.executor.am_encoder_infer_sess.run(
- None, input_feed={'text': phone_ids[i].numpy()}
- )
+ None, input_feed={'text': phone_ids[i].numpy()})
hs = orig_hs[0]
am_decoder_output = self.engine.executor.am_decoder_sess.run(
- None, input_feed={'xs': hs})
+ None, input_feed={'xs': hs})
am_postnet_output = self.engine.executor.am_postnet_sess.run(
- None,
- input_feed={
- 'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
- })
+ None,
+ input_feed={
+ 'xs': np.transpose(am_decoder_output[0], (0, 2, 1))
+ })
am_output_data = am_decoder_output + np.transpose(
am_postnet_output[0], (0, 2, 1))
normalized_mel = am_output_data[0][0]
- mel = denorm(normalized_mel, self.engine.executor.am_mu, self.engine.executor.am_std)
+ mel = denorm(normalized_mel, self.engine.executor.am_mu,
+ self.engine.executor.am_std)
wav = self.engine.executor.voc_sess.run(
- output_names=None, input_feed={'logmel': mel})[0]
+ output_names=None, input_feed={'logmel': mel})[0]
wav_list.append(wav)
wavs = np.concatenate(wav_list)
return wavs
-
+
def streamTTS(self, text):
get_tone_ids = False
@@ -88,9 +88,7 @@ class TTS:
# front
input_ids = self.frontend.get_input_ids(
- text,
- merge_sentences=merge_sentences,
- get_tone_ids=get_tone_ids)
+ text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
for i in range(len(phone_ids)):
@@ -105,14 +103,15 @@ class TTS:
mel = mel[0]
# voc streaming
- mel_chunks = get_chunks(mel, self.config.voc_block, self.config.voc_pad, "voc")
+ mel_chunks = get_chunks(mel, self.config.voc_block,
+ self.config.voc_pad, "voc")
voc_chunk_num = len(mel_chunks)
for i, mel_chunk in enumerate(mel_chunks):
sub_wav = self.executor.voc_sess.run(
output_names=None, input_feed={'logmel': mel_chunk})
- sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i,
- self.config.voc_block, self.config.voc_pad,
- self.config.voc_upsample)
+ sub_wav = self.depadding(
+ sub_wav[0], voc_chunk_num, i, self.config.voc_block,
+ self.config.voc_pad, self.config.voc_upsample)
yield self.after_process(sub_wav)
@@ -130,7 +129,8 @@ class TTS:
end = min(self.config.voc_block + self.config.voc_pad, mel_len)
# streaming am
- hss = get_chunks(orig_hs, self.config.am_block, self.config.am_pad, "am")
+ hss = get_chunks(orig_hs, self.config.am_block,
+ self.config.am_pad, "am")
am_chunk_num = len(hss)
for i, hs in enumerate(hss):
am_decoder_output = self.executor.am_decoder_sess.run(
@@ -147,7 +147,8 @@ class TTS:
sub_mel = denorm(normalized_mel, self.executor.am_mu,
self.executor.am_std)
sub_mel = self.depadding(sub_mel, am_chunk_num, i,
- self.config.am_block, self.config.am_pad, 1)
+ self.config.am_block,
+ self.config.am_pad, 1)
if i == 0:
mel_streaming = sub_mel
@@ -165,23 +166,22 @@ class TTS:
output_names=None, input_feed={'logmel': voc_chunk})
sub_wav = self.depadding(
sub_wav[0], voc_chunk_num, voc_chunk_id,
- self.config.voc_block, self.config.voc_pad, self.config.voc_upsample)
+ self.config.voc_block, self.config.voc_pad,
+ self.config.voc_upsample)
yield self.after_process(sub_wav)
voc_chunk_id += 1
- start = max(
- 0, voc_chunk_id * self.config.voc_block - self.config.voc_pad)
- end = min(
- (voc_chunk_id + 1) * self.config.voc_block + self.config.voc_pad,
- mel_len)
+ start = max(0, voc_chunk_id * self.config.voc_block -
+ self.config.voc_pad)
+ end = min((voc_chunk_id + 1) * self.config.voc_block +
+ self.config.voc_pad, mel_len)
else:
logging.error(
"Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts."
- )
+ )
-
def streamTTSBytes(self, text):
for wav in self.engine.executor.infer(
text=text,
@@ -191,19 +191,14 @@ class TTS:
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
yield wav_bytes
-
-
+
def after_process(self, wav):
# for tvm
wav = float2pcm(wav) # float32 to int16
wav_bytes = wav.tobytes() # to bytes
wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64
return wav_base64
-
+
def streamTTS_TVM(self, text):
# 用 TVM 优化
pass
-
-
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr.py b/demos/speech_web/speech_server/src/SpeechBase/vpr.py
index 29ee986e3..46e9c0366 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/vpr.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/vpr.py
@@ -1,11 +1,14 @@
# vpr Demo 没有使用 mysql 与 muilvs, 仅用于docker演示
import logging
+
import faiss
-from matplotlib import use
import numpy as np
+from matplotlib import use
+
from .sql_helper import DataBase
from .vpr_encode import get_audio_embedding
+
class VPR:
def __init__(self, db_path, dim, top_k) -> None:
# 初始化
@@ -14,15 +17,15 @@ class VPR:
self.top_k = top_k
self.dtype = np.float32
self.vpr_idx = 0
-
+
# db 初始化
self.db = DataBase(db_path)
-
+
# faiss 初始化
index_ip = faiss.IndexFlatIP(dim)
self.index_ip = faiss.IndexIDMap(index_ip)
self.init()
-
+
def init(self):
# demo 初始化,把 mysql中的向量注册到 faiss 中
sql_dbs = self.db.select_all()
@@ -34,12 +37,13 @@ class VPR:
if len(vc.shape) == 1:
vc = np.expand_dims(vc, axis=0)
# 构建数据库
- self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64'))
+ self.index_ip.add_with_ids(vc, np.array(
+ (idx, )).astype('int64'))
logging.info("faiss 构建完毕")
-
+
def faiss_enroll(self, idx, vc):
- self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64'))
-
+ self.index_ip.add_with_ids(vc, np.array((idx, )).astype('int64'))
+
def vpr_enroll(self, username, wav_path):
# 注册声纹
emb = get_audio_embedding(wav_path)
@@ -53,21 +57,22 @@ class VPR:
else:
last_idx, mess = None
return last_idx
-
+
def vpr_recog(self, wav_path):
# 识别声纹
emb_search = get_audio_embedding(wav_path)
-
+
if emb_search is not None:
emb_search = np.expand_dims(emb_search, axis=0)
D, I = self.index_ip.search(emb_search, self.top_k)
D = D.tolist()[0]
- I = I.tolist()[0]
- return [(round(D[i] * 100, 2 ), I[i]) for i in range(len(D)) if I[i] != -1]
+ I = I.tolist()[0]
+ return [(round(D[i] * 100, 2), I[i]) for i in range(len(D))
+ if I[i] != -1]
else:
logging.error("识别失败")
return None
-
+
def do_search_vpr(self, wav_path):
spk_ids, paths, scores = [], [], []
recog_result = self.vpr_recog(wav_path)
@@ -78,41 +83,39 @@ class VPR:
scores.append(score)
paths.append("")
return spk_ids, paths, scores
-
+
def vpr_del(self, username):
# 根据用户username, 删除声纹
# 查用户ID,删除对应向量
res = self.db.select_by_username(username)
for r in res:
idx = r['id']
- self.index_ip.remove_ids(np.array((idx,)).astype('int64'))
-
+ self.index_ip.remove_ids(np.array((idx, )).astype('int64'))
+
self.db.drop_by_username(username)
-
+
def vpr_list(self):
# 获取数据列表
return self.db.select_all()
-
+
def do_list(self):
spk_ids, vpr_ids = [], []
for res in self.db.select_all():
spk_ids.append(res['username'])
vpr_ids.append(res['id'])
- return spk_ids, vpr_ids
-
+ return spk_ids, vpr_ids
+
def do_get_wav(self, vpr_idx):
- res = self.db.select_by_id(vpr_idx)
- return res[0]['wavpath']
-
-
+ res = self.db.select_by_id(vpr_idx)
+ return res[0]['wavpath']
+
def vpr_data(self, idx):
# 获取对应ID的数据
res = self.db.select_by_id(idx)
return res
-
+
def vpr_droptable(self):
# 删除表
self.db.drop_table()
# 清空 faiss
self.index_ip.reset()
-
diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
index a6a00e4d0..9d052fd98 100644
--- a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
+++ b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py
@@ -1,9 +1,12 @@
-from paddlespeech.cli.vector import VectorExecutor
-import numpy as np
import logging
+import numpy as np
+
+from paddlespeech.cli.vector import VectorExecutor
+
vector_executor = VectorExecutor()
+
def get_audio_embedding(path):
"""
Use vpr_inference to generate embedding of audio
@@ -16,5 +19,3 @@ def get_audio_embedding(path):
except Exception as e:
logging.error(f"Error with embedding:{e}")
return None
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/WebsocketManeger.py b/demos/speech_web/speech_server/src/WebsocketManeger.py
index 5edde8430..954d849a5 100644
--- a/demos/speech_web/speech_server/src/WebsocketManeger.py
+++ b/demos/speech_web/speech_server/src/WebsocketManeger.py
@@ -2,6 +2,7 @@ from typing import List
from fastapi import WebSocket
+
class ConnectionManager:
def __init__(self):
# 存放激活的ws连接对象
@@ -28,4 +29,4 @@ class ConnectionManager:
await connection.send_text(message)
-manager = ConnectionManager()
\ No newline at end of file
+manager = ConnectionManager()
diff --git a/demos/speech_web/speech_server/src/robot.py b/demos/speech_web/speech_server/src/robot.py
index b971c57b5..031d91eb2 100644
--- a/demos/speech_web/speech_server/src/robot.py
+++ b/demos/speech_web/speech_server/src/robot.py
@@ -1,60 +1,65 @@
-from paddlespeech.cli.asr.infer import ASRExecutor
-import soundfile as sf
import os
-import librosa
+import librosa
+import soundfile as sf
from src.SpeechBase.asr import ASR
-from src.SpeechBase.tts import TTS
from src.SpeechBase.nlp import NLP
+from src.SpeechBase.tts import TTS
+
+from paddlespeech.cli.asr.infer import ASRExecutor
class Robot:
- def __init__(self, asr_config, tts_config,asr_init_path,
+ def __init__(self,
+ asr_config,
+ tts_config,
+ asr_init_path,
ie_model_path=None) -> None:
self.nlp = NLP(ie_model_path=ie_model_path)
self.asr = ASR(config_path=asr_config)
self.tts = TTS(config_path=tts_config)
self.tts_sample_rate = 24000
self.asr_sample_rate = 16000
-
+
# 流式识别效果不如端到端的模型,这里流式模型与端到端模型分开
self.asr_model = ASRExecutor()
self.asr_name = "conformer_wenetspeech"
self.warm_up_asrmodel(asr_init_path)
-
- def warm_up_asrmodel(self, asr_init_path):
+ def warm_up_asrmodel(self, asr_init_path):
if not os.path.exists(asr_init_path):
path_dir = os.path.dirname(asr_init_path)
if not os.path.exists(path_dir):
os.makedirs(path_dir, exist_ok=True)
-
+
# TTS生成,采样率24000
text = "生成初始音频"
self.text2speech(text, asr_init_path)
-
+
# asr model初始化
- self.asr_model(asr_init_path, model=self.asr_name,lang='zh',
- sample_rate=16000, force_yes=True)
-
-
+ self.asr_model(
+ asr_init_path,
+ model=self.asr_name,
+ lang='zh',
+ sample_rate=16000,
+ force_yes=True)
+
def speech2text(self, audio_file):
self.asr_model.preprocess(self.asr_name, audio_file)
self.asr_model.infer(self.asr_name)
res = self.asr_model.postprocess()
return res
-
+
def text2speech(self, text, outpath):
wav = self.tts.offlineTTS(text)
- sf.write(
- outpath, wav, samplerate=self.tts_sample_rate)
+ sf.write(outpath, wav, samplerate=self.tts_sample_rate)
res = wav
return res
-
+
def text2speechStream(self, text):
for sub_wav_base64 in self.tts.streamTTS(text=text):
yield sub_wav_base64
-
+
def text2speechStreamBytes(self, text):
for wav_bytes in self.tts.streamTTSBytes(text=text):
yield wav_bytes
@@ -66,5 +71,3 @@ class Robot:
def ie(self, text):
result = self.nlp.ie(text)
return result
-
-
\ No newline at end of file
diff --git a/demos/speech_web/speech_server/src/util.py b/demos/speech_web/speech_server/src/util.py
index 34005d919..4a566b6ee 100644
--- a/demos/speech_web/speech_server/src/util.py
+++ b/demos/speech_web/speech_server/src/util.py
@@ -1,18 +1,13 @@
import random
+
def randName(n=5):
- return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba',n))
+ return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', n))
+
def SuccessRequest(result=None, message="ok"):
- return {
- "code": 0,
- "result":result,
- "message": message
- }
+ return {"code": 0, "result": result, "message": message}
+
def ErrorRequest(result=None, message="error"):
- return {
- "code": -1,
- "result":result,
- "message": message
- }
\ No newline at end of file
+ return {"code": -1, "result": result, "message": message}
diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py
index 4b89b48fd..09a9c9750 100755
--- a/demos/streaming_asr_server/local/rtf_from_log.py
+++ b/demos/streaming_asr_server/local/rtf_from_log.py
@@ -34,7 +34,7 @@ if __name__ == '__main__':
n = 0
for m in rtfs:
# not accurate, may have duplicate log
- n += 1
+ n += 1
T += m['T']
P += m['P']
diff --git a/docs/requirements.txt b/docs/requirements.txt
index bf1486c5e..a77da9f82 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,12 +1,7 @@
-myst-parser
-numpydoc
-recommonmark>=0.5.0
-sphinx
-sphinx-autobuild
-sphinx-markdown-tables
-sphinx_rtd_theme
-paddlepaddle>=2.2.2
+braceexpand
+colorlog
editdistance
+fastapi
g2p_en
g2pM
h5py
@@ -14,39 +8,44 @@ inflect
jieba
jsonlines
kaldiio
+keyboard
librosa==0.8.1
loguru
matplotlib
+myst-parser
nara_wpe
+numpydoc
onnxruntime
-pandas
paddlenlp
+paddlepaddle>=2.2.2
paddlespeech_feat
+pandas
+pathos == 0.2.8
+pattern_singleton
Pillow>=9.0.0
praatio==5.0.0
+prettytable
pypinyin
pypinyin-dict
python-dateutil
pyworld==0.2.12
+recommonmark>=0.5.0
resampy==0.2.2
sacrebleu
scipy
sentencepiece~=0.1.96
soundfile~=0.10
+sphinx
+sphinx-autobuild
+sphinx-markdown-tables
+sphinx_rtd_theme
textgrid
timer
tqdm
typeguard
+uvicorn
visualdl
webrtcvad
+websockets
yacs~=0.1.8
-prettytable
zhon
-colorlog
-pathos == 0.2.8
-fastapi
-websockets
-keyboard
-uvicorn
-pattern_singleton
-braceexpand
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index c94cf0b86..cd9b1807b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -20,10 +20,11 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
+import os
+import sys
+
import recommonmark.parser
import sphinx_rtd_theme
-import sys
-import os
sys.path.insert(0, os.path.abspath('../..'))
autodoc_mock_imports = ["soundfile", "librosa"]
diff --git a/examples/csmsc/tts1/README.md b/examples/csmsc/tts1/README.md
index 0725eda2d..a5b48e6b4 100644
--- a/examples/csmsc/tts1/README.md
+++ b/examples/csmsc/tts1/README.md
@@ -113,9 +113,9 @@ optional arguments:
--transformer-tts-stat TRANSFORMER_TTS_STAT
mean and standard deviation used to normalize
spectrogram when training transformer tts.
- --waveflow-config WAVEFLOW_CONFIG
+ --voc-config WAVEFLOW_CONFIG
waveflow config file.
- --waveflow-checkpoint WAVEFLOW_CHECKPOINT
+  --voc-ckpt WAVEFLOW_CHECKPOINT
waveflow checkpoint to load.
--phones-dict PHONES_DICT
phone vocabulary file.
@@ -150,9 +150,9 @@ optional arguments:
--transformer-tts-stat TRANSFORMER_TTS_STAT
mean and standard deviation used to normalize
spectrogram when training transformer tts.
- --waveflow-config WAVEFLOW_CONFIG
+ --voc-config WAVEFLOW_CONFIG
waveflow config file.
- --waveflow-checkpoint WAVEFLOW_CHECKPOINT
+ --voc-ckpt WAVEFLOW_CHECKPOINT
waveflow checkpoint to load.
--phones-dict PHONES_DICT
phone vocabulary file.
@@ -170,14 +170,14 @@ optional arguments:
## Pretrained Model
Pretrained Model can be downloaded here:
-- [transformer_tts_csmsc_ckpt.zip](https://pan.baidu.com/s/1jan_ZXCGKI7DHvS2jxWIEw?pwd=9i0t)
+- [transformer_tts_csmsc_ckpt.zip](https://pan.baidu.com/s/1-6uvjQDxS0-6c9XZPBYqBQ?pwd=jjc3)
TransformerTTS checkpoint contains files listed below.
```text
transformer_tts_csmsc_ckpt
├── default.yaml # default config used to train transformer_tts
├── phone_id_map.txt # phone vocabulary file when training transformer_tts
-├── snapshot_iter_1118250.pdz # model parameters and optimizer states
+├── snapshot_iter_675000.pdz # model parameters and optimizer states
└── speech_stats.npy # statistics used to normalize spectrogram when training transformer_tts
```
You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained transformer_tts and waveflow models.
@@ -190,8 +190,8 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
--transformer-tts-config=transformer_tts_csmsc_ckpt/default.yaml \
-  --transformer-tts-checkpoint=transformer_tts_csmsc_ckpt/snapshot_iter_1118250.pdz \
+  --transformer-tts-checkpoint=transformer_tts_csmsc_ckpt/snapshot_iter_675000.pdz \
--transformer-tts-stat=transformer_tts_csmsc_ckpt/speech_stats.npy \
- --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \
- --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \
+ --voc-config=waveflow_ljspeech_ckpt_0.3/config.yaml \
+ --voc-ckpt=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \
--text=${BIN_DIR}/../sentences.txt \
--output-dir=exp/default/test_e2e \
--phones-dict=transformer_tts_csmsc_ckpt/phone_id_map.txt
diff --git a/examples/csmsc/tts1/conf/default.yaml b/examples/csmsc/tts1/conf/default.yaml
index 456b6a1e3..d4a62b836 100644
--- a/examples/csmsc/tts1/conf/default.yaml
+++ b/examples/csmsc/tts1/conf/default.yaml
@@ -1,17 +1,16 @@
-
-fs : 22050 # Hz, sample rate
-n_fft : 1024 # FFT size (samples).
-win_length : 1024 # Window length (samples). 46.4ms
-n_shift : 256 # Hop size (samples). 11.6ms
-fmin : 0 # Hz, min frequency when converting to mel
-fmax : 8000 # Hz, max frequency when converting to mel
+fs : 24000 # Hz, sample rate
+n_fft : 2048 # FFT size (samples).
+win_length : 1200      # Window length (samples). 50.0ms
+n_shift : 300          # Hop size (samples). 12.5ms
+fmin : 80 # Hz, min frequency when converting to mel
+fmax : 7600 # Hz, max frequency when converting to mel
n_mels : 80 # mel bands
window: "hann" # Window function.
###########################################################
# DATA SETTING #
###########################################################
-batch_size: 16
+batch_size: 4
num_workers: 2
##########################################################
@@ -82,11 +81,11 @@ optimizer:
###########################################################
# TRAINING SETTING #
###########################################################
-max_epoch: 500
+max_epoch: 300
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
-seed: 10086
\ No newline at end of file
+seed: 10086
diff --git a/examples/csmsc/tts1/local/preprocess.sh b/examples/csmsc/tts1/local/preprocess.sh
index f92664cef..ac5e8ec58 100644
--- a/examples/csmsc/tts1/local/preprocess.sh
+++ b/examples/csmsc/tts1/local/preprocess.sh
@@ -5,15 +5,27 @@ stop_stage=100
config_path=$1
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ # get durations from MFA's result
+ echo "Generate durations.txt from MFA results ..."
+ python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+ --inputdir=./baker_alignment_tone \
+ --output=durations.txt \
+ --config=${config_path}
+fi
+
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
- python3 ${BIN_DIR}/preprocess.py \
- --dataset=ljspeech \
+ python3 ${BIN_DIR}/preprocess_new.py \
+    --dataset=baker \
--rootdir=~/datasets/BZNSYP/ \
--dumpdir=dump \
- --config-path=conf/default.yaml \
- --num-cpu=8
+    --dur-file=durations.txt \
+ --config-path=${config_path} \
+ --num-cpu=8 \
+ --cut-sil=True
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
diff --git a/examples/csmsc/tts1/local/synthesize.sh b/examples/csmsc/tts1/local/synthesize.sh
index 9d1c47b39..8c14f647d 100644
--- a/examples/csmsc/tts1/local/synthesize.sh
+++ b/examples/csmsc/tts1/local/synthesize.sh
@@ -3,15 +3,93 @@
config_path=$1
train_output_path=$2
ckpt_name=$3
+stage=0
+stop_stage=0
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/synthesize.py \
- --transformer-tts-config=${config_path} \
- --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
- --transformer-tts-stat=dump/train/speech_stats.npy \
- --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \
- --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \
- --test-metadata=dump/test/norm/metadata.jsonl \
- --output-dir=${train_output_path}/test \
- --phones-dict=dump/phone_id_map.txt
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/synthesize.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=pwgan_csmsc \
+ --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
+ --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
+ --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=mb_melgan_csmsc \
+ --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
+ --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# style melgan
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=style_melgan_csmsc \
+ --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+ --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "in hifigan syn"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=hifigan_csmsc \
+ --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt
+fi
+
+# wavernn
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "in wavernn syn"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=wavernn_csmsc \
+ --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \
+ --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \
+ --test_metadata=dump/test/norm/metadata.jsonl \
+ --output_dir=${train_output_path}/test \
+ --phones_dict=dump/phone_id_map.txt
+fi
\ No newline at end of file
diff --git a/examples/csmsc/tts1/local/synthesize_e2e.sh b/examples/csmsc/tts1/local/synthesize_e2e.sh
index 25a862f90..ce2f30afd 100644
--- a/examples/csmsc/tts1/local/synthesize_e2e.sh
+++ b/examples/csmsc/tts1/local/synthesize_e2e.sh
@@ -4,14 +4,106 @@ config_path=$1
train_output_path=$2
ckpt_name=$3
-FLAGS_allocator_strategy=naive_best_fit \
-FLAGS_fraction_of_gpu_memory_to_use=0.01 \
-python3 ${BIN_DIR}/synthesize_e2e.py \
- --transformer-tts-config=${config_path} \
- --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
- --transformer-tts-stat=dump/train/speech_stats.npy \
- --waveflow-config=waveflow_ljspeech_ckpt_0.3/config.yaml \
- --waveflow-checkpoint=waveflow_ljspeech_ckpt_0.3/step-2000000.pdparams \
- --text=${BIN_DIR}/../sentences_en.txt \
- --output-dir=${train_output_path}/test_e2e \
- --phones-dict=dump/phone_id_map.txt
+stage=0
+stop_stage=0
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize_e2e.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=pwgan_csmsc \
+ --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
+ --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
+ --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ #--inference_dir=${train_output_path}/inference
+
+fi
+
+# for more GAN Vocoders
+# multi band melgan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize_e2e.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=mb_melgan_csmsc \
+ --voc_config=mb_melgan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=mb_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1000000.pdz\
+ --voc_stat=mb_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ #--inference_dir=${train_output_path}/inference
+fi
+
+# the pretrained models have not been released yet
+# style melgan
+# style melgan's dygraph-to-static-graph export is not ready yet
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=style_melgan_csmsc \
+ --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+ --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt
+ # --inference_dir=${train_output_path}/inference
+fi
+
+# hifigan
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "in hifigan syn_e2e"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=hifigan_csmsc \
+ --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
+ --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
+ --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ #--inference_dir=${train_output_path}/inference
+fi
+
+# wavernn
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "in wavernn syn_e2e"
+ FLAGS_allocator_strategy=naive_best_fit \
+ FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+ python3 ${BIN_DIR}/../synthesize_e2e.py \
+ --transformer-tts-config=${config_path} \
+ --transformer-tts-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+ --transformer-tts-stat=dump/train/speech_stats.npy \
+ --voc=wavernn_csmsc \
+ --voc_config=wavernn_csmsc_ckpt_0.2.0/default.yaml \
+ --voc_ckpt=wavernn_csmsc_ckpt_0.2.0/snapshot_iter_400000.pdz \
+ --voc_stat=wavernn_csmsc_ckpt_0.2.0/feats_stats.npy \
+ --lang=zh \
+ --text=${BIN_DIR}/../sentences.txt \
+ --output_dir=${train_output_path}/test_e2e \
+ --phones_dict=dump/phone_id_map.txt \
+ #--inference_dir=${train_output_path}/inference
+fi
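
Note: both staged scripts above hardcode stage=0 and stop_stage=0, so only the
pwgan branch runs by default; selecting another vocoder (e.g. hifigan at stage 3)
currently requires editing those two variables at the top of the script. A
typical invocation, with an illustrative checkpoint name, is
./local/synthesize_e2e.sh conf/default.yaml exp/default snapshot_iter_24000.pdz
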
diff --git a/examples/ernie_sat/local/inference.py b/examples/ernie_sat/local/inference.py
index e6a0788fd..cb0aba5c1 100644
--- a/examples/ernie_sat/local/inference.py
+++ b/examples/ernie_sat/local/inference.py
@@ -26,15 +26,15 @@ from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
+from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text
-from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
-from paddlespeech.t2s.models.ernie_sat.mlm import build_model_from_file
-
random.seed(0)
np.random.seed(0)
diff --git a/examples/ernie_sat/local/inference_new.py b/examples/ernie_sat/local/inference_new.py
index 525967eb1..53a46298b 100644
--- a/examples/ernie_sat/local/inference_new.py
+++ b/examples/ernie_sat/local/inference_new.py
@@ -27,15 +27,15 @@ from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
+from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text
-from yacs.config import CfgNode
-
-from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
-from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT
random.seed(0)
np.random.seed(0)
diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py
index 4b1c0ef3d..b781c4a8e 100644
--- a/paddlespeech/__init__.py
+++ b/paddlespeech/__init__.py
@@ -14,5 +14,3 @@
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
-
-
diff --git a/paddlespeech/audio/__init__.py b/paddlespeech/audio/__init__.py
index 83be8e32e..a91958105 100644
--- a/paddlespeech/audio/__init__.py
+++ b/paddlespeech/audio/__init__.py
@@ -14,12 +14,12 @@
from . import compliance
from . import datasets
from . import features
-from . import text
-from . import transform
-from . import streamdata
from . import functional
from . import io
from . import metric
from . import sox_effects
+from . import streamdata
+from . import text
+from . import transform
from .backends import load
from .backends import save
diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py
index 753fcc11b..47a2e79b3 100644
--- a/paddlespeech/audio/streamdata/__init__.py
+++ b/paddlespeech/audio/streamdata/__init__.py
@@ -4,67 +4,66 @@
# Modified from https://github.com/webdataset/webdataset
#
# flake8: noqa
-
-from .cache import (
- cached_tarfile_samples,
- cached_tarfile_to_samples,
- lru_cleanup,
- pipe_cleaner,
-)
-from .compat import WebDataset, WebLoader, FluidWrapper
-from .extradatasets import MockDataset, with_epoch, with_length
-from .filters import (
- associate,
- batched,
- decode,
- detshuffle,
- extract_keys,
- getfirst,
- info,
- map,
- map_dict,
- map_tuple,
- pipelinefilter,
- rename,
- rename_keys,
- audio_resample,
- select,
- shuffle,
- slice,
- to_tuple,
- transform_with,
- unbatched,
- xdecode,
- audio_data_filter,
- audio_tokenize,
- audio_resample,
- audio_compute_fbank,
- audio_spec_aug,
- sort,
- audio_padding,
- audio_cmvn,
- placeholder,
-)
-from .handlers import (
- ignore_and_continue,
- ignore_and_stop,
- reraise_exception,
- warn_and_continue,
- warn_and_stop,
-)
+from .cache import cached_tarfile_samples
+from .cache import cached_tarfile_to_samples
+from .cache import lru_cleanup
+from .cache import pipe_cleaner
+from .compat import FluidWrapper
+from .compat import WebDataset
+from .compat import WebLoader
+from .extradatasets import MockDataset
+from .extradatasets import with_epoch
+from .extradatasets import with_length
+from .filters import associate
+from .filters import audio_cmvn
+from .filters import audio_compute_fbank
+from .filters import audio_data_filter
+from .filters import audio_padding
+from .filters import audio_resample
+from .filters import audio_spec_aug
+from .filters import audio_tokenize
+from .filters import batched
+from .filters import decode
+from .filters import detshuffle
+from .filters import extract_keys
+from .filters import getfirst
+from .filters import info
+from .filters import map
+from .filters import map_dict
+from .filters import map_tuple
+from .filters import pipelinefilter
+from .filters import placeholder
+from .filters import rename
+from .filters import rename_keys
+from .filters import select
+from .filters import shuffle
+from .filters import slice
+from .filters import sort
+from .filters import to_tuple
+from .filters import transform_with
+from .filters import unbatched
+from .filters import xdecode
+from .handlers import ignore_and_continue
+from .handlers import ignore_and_stop
+from .handlers import reraise_exception
+from .handlers import warn_and_continue
+from .handlers import warn_and_stop
+from .mix import RandomMix
+from .mix import RoundRobin
from .pipeline import DataPipeline
-from .shardlists import (
- MultiShardSample,
- ResampledShards,
- SimpleShardList,
- non_empty,
- resampled,
- shardspec,
- single_node_only,
- split_by_node,
- split_by_worker,
-)
-from .tariterators import tarfile_samples, tarfile_to_samples
-from .utils import PipelineStage, repeatedly
-from .writer import ShardWriter, TarWriter, numpy_dumps
-from .mix import RandomMix, RoundRobin
+from .shardlists import MultiShardSample
+from .shardlists import non_empty
+from .shardlists import resampled
+from .shardlists import ResampledShards
+from .shardlists import shardspec
+from .shardlists import SimpleShardList
+from .shardlists import single_node_only
+from .shardlists import split_by_node
+from .shardlists import split_by_worker
+from .tariterators import tarfile_samples
+from .tariterators import tarfile_to_samples
+from .utils import PipelineStage
+from .utils import repeatedly
+from .writer import numpy_dumps
+from .writer import ShardWriter
+from .writer import TarWriter
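
With the imports normalized to one name per line, the streaming API of this
package is easier to scan. As orientation, a minimal sketch of how these
exports compose into a pipeline (the shard pattern and symbol table are
hypothetical, and real recipes add spec augmentation, sorting, padding and
CMVN stages):

    import paddlespeech.audio.streamdata as sd

    vocab = {"<blank>": 0, "ni": 1, "hao": 2}  # hypothetical symbol table
    dataset = sd.DataPipeline(
        sd.SimpleShardList("data/train-{000..009}.tar"),  # hypothetical shards
        sd.split_by_worker,  # partition shards across DataLoader workers
        sd.tarfile_to_samples(),  # group paired <utt>.wav/<utt>.txt entries
        sd.audio_resample(resample_rate=16000),
        sd.audio_tokenize(symbol_table=vocab),
        sd.audio_compute_fbank(num_mel_bins=80),
        sd.batched(16), )
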
diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py
index ca0e2ea2f..d7f7937bd 100644
--- a/paddlespeech/audio/streamdata/autodecode.py
+++ b/paddlespeech/audio/streamdata/autodecode.py
@@ -5,18 +5,19 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
-
"""Automatically decode webdataset samples."""
-
-import io, json, os, pickle, re, tempfile
+import io
+import json
+import os
+import pickle
+import re
+import tempfile
from functools import partial
import numpy as np
-
"""Extensions passed on to the image decoder."""
image_extensions = "jpg jpeg png ppm pgm pbm pnm".split()
-
################################################################
# handle basic datatypes
################################################################
@@ -128,7 +129,7 @@ def call_extension_handler(key, data, f, extensions):
target = target.split(".")
if len(target) > len(extension):
continue
- if extension[-len(target) :] == target:
+ if extension[-len(target):] == target:
return f(data)
return None
@@ -268,7 +269,6 @@ def imagehandler(imagespec, extensions=image_extensions):
################################################################
# torch video
################################################################
-
'''
def torch_video(key, data):
"""Decode video using the torchvideo library.
@@ -289,7 +289,6 @@ def torch_video(key, data):
return torchvision.io.read_video(fname, pts_unit="sec")
'''
-
################################################################
# paddlespeech.audio
################################################################
@@ -359,7 +358,6 @@ def gzfilter(key, data):
 # decode entire training samples
################################################################
-
default_pre_handlers = [gzfilter]
default_post_handlers = [basichandlers]
@@ -387,7 +385,8 @@ class Decoder:
pre = default_pre_handlers
if post is None:
post = default_post_handlers
- assert all(callable(h) for h in handlers), f"one of {handlers} not callable"
+ assert all(callable(h)
+ for h in handlers), f"one of {handlers} not callable"
assert all(callable(h) for h in pre), f"one of {pre} not callable"
assert all(callable(h) for h in post), f"one of {post} not callable"
self.handlers = pre + handlers + post
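
The reflowed assertions above are the whole contract of Decoder construction:
pre handlers run first, user handlers next, post handlers last. A minimal
sketch, assuming the upstream webdataset behavior that Decoder instances are
callable and that basichandlers decodes .txt entries to str:

    from paddlespeech.audio.streamdata import autodecode

    decoder = autodecode.Decoder([autodecode.imagehandler("rgb")])
    sample = decoder({"__key__": "utt0", "txt": b"hello"})
    assert sample["txt"] == "hello"  # gzfilter ran first, basichandlers last
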
diff --git a/paddlespeech/audio/streamdata/cache.py b/paddlespeech/audio/streamdata/cache.py
index e7bbffa1b..5cd94aa6c 100644
--- a/paddlespeech/audio/streamdata/cache.py
+++ b/paddlespeech/audio/streamdata/cache.py
@@ -2,7 +2,11 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
-import itertools, os, random, re, sys
+import itertools
+import os
+import random
+import re
+import sys
from urllib.parse import urlparse
from . import filters
@@ -40,7 +44,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False):
os.remove(fname)
-def download(url, dest, chunk_size=1024 ** 2, verbose=False):
+def download(url, dest, chunk_size=1024**2, verbose=False):
"""Download a file from `url` to `dest`."""
temp = dest + f".temp{os.getpid()}"
with gopen.gopen(url) as stream:
@@ -65,12 +69,11 @@ def pipe_cleaner(spec):
def get_file_cached(
- spec,
- cache_size=-1,
- cache_dir=None,
- url_to_name=pipe_cleaner,
- verbose=False,
-):
+ spec,
+ cache_size=-1,
+ cache_dir=None,
+ url_to_name=pipe_cleaner,
+ verbose=False, ):
if cache_size == -1:
cache_size = default_cache_size
if cache_dir is None:
@@ -107,15 +110,14 @@ verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0"))
def cached_url_opener(
- data,
- handler=reraise_exception,
- cache_size=-1,
- cache_dir=None,
- url_to_name=pipe_cleaner,
- validator=check_tar_format,
- verbose=False,
- always=False,
-):
+ data,
+ handler=reraise_exception,
+ cache_size=-1,
+ cache_dir=None,
+ url_to_name=pipe_cleaner,
+ validator=check_tar_format,
+ verbose=False,
+ always=False, ):
"""Given a stream of url names (packaged in `dict(url=url)`), yield opened streams."""
verbose = verbose or verbose_cache
for sample in data:
@@ -132,8 +134,7 @@ def cached_url_opener(
cache_size=cache_size,
cache_dir=cache_dir,
url_to_name=url_to_name,
- verbose=verbose,
- )
+ verbose=verbose, )
if verbose:
print("# opening %s" % dest, file=sys.stderr)
assert os.path.exists(dest)
@@ -143,9 +144,8 @@ def cached_url_opener(
data = f.read(200)
os.remove(dest)
raise ValueError(
- "%s (%s) is not a tar archive, but a %s, contains %s"
- % (dest, url, ftype, repr(data))
- )
+ "%s (%s) is not a tar archive, but a %s, contains %s" %
+ (dest, url, ftype, repr(data)))
try:
stream = open(dest, "rb")
sample.update(stream=stream)
@@ -158,7 +158,7 @@ def cached_url_opener(
continue
raise exn
except Exception as exn:
- exn.args = exn.args + (url,)
+ exn.args = exn.args + (url, )
if handler(exn):
continue
else:
@@ -166,14 +166,13 @@ def cached_url_opener(
def cached_tarfile_samples(
- src,
- handler=reraise_exception,
- cache_size=-1,
- cache_dir=None,
- verbose=False,
- url_to_name=pipe_cleaner,
- always=False,
-):
+ src,
+ handler=reraise_exception,
+ cache_size=-1,
+ cache_dir=None,
+ verbose=False,
+ url_to_name=pipe_cleaner,
+ always=False, ):
streams = cached_url_opener(
src,
handler=handler,
@@ -181,8 +180,7 @@ def cached_tarfile_samples(
cache_dir=cache_dir,
verbose=verbose,
url_to_name=url_to_name,
- always=always,
- )
+ always=always, )
samples = tar_file_and_group_expander(streams, handler=handler)
return samples
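
The reformatted caching helpers can be exercised standalone; a sketch with a
hypothetical URL and cache directory (the file is downloaded once, then the
local copy is reused, and lru_cleanup evicts the oldest entries when
cache_size is exceeded):

    from paddlespeech.audio.streamdata.cache import get_file_cached

    path = get_file_cached(
        "https://example.com/shards/train-000.tar",  # hypothetical shard URL
        cache_size=10 * 1024**3,  # 10 GB budget
        cache_dir="./wds_cache")
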
diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py
index deda53384..84f46abff 100644
--- a/paddlespeech/audio/streamdata/compat.py
+++ b/paddlespeech/audio/streamdata/compat.py
@@ -6,13 +6,18 @@ from dataclasses import dataclass
from itertools import islice
from typing import List
-import braceexpand, yaml
+import braceexpand
+import yaml
from . import autodecode
-from . import cache, filters, shardlists, tariterators
+from . import cache
+from . import filters
+from . import shardlists
+from . import tariterators
from .filters import reraise_exception
+from .paddle_utils import DataLoader
+from .paddle_utils import IterableDataset
from .pipeline import DataPipeline
-from .paddle_utils import DataLoader, IterableDataset
class FluidInterface:
@@ -26,7 +31,8 @@ class FluidInterface:
return self.compose(filters.unbatched())
def listed(self, batchsize, partial=True):
- return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None)
+ return self.compose(
+ filters.batched(), batchsize=batchsize, collation_fn=None)
def unlisted(self):
return self.compose(filters.unlisted())
@@ -43,9 +49,19 @@ class FluidInterface:
def map(self, f, handler=reraise_exception):
return self.compose(filters.map(f, handler=handler))
- def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception):
- handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args]
- decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial)
+ def decode(self,
+ *args,
+ pre=None,
+ post=None,
+ only=None,
+ partial=False,
+ handler=reraise_exception):
+ handlers = [
+ autodecode.ImageHandler(x) if isinstance(x, str) else x
+ for x in args
+ ]
+ decoder = autodecode.Decoder(
+ handlers, pre=pre, post=post, only=only, partial=partial)
return self.map(decoder, handler=handler)
def map_dict(self, handler=reraise_exception, **kw):
@@ -80,12 +96,12 @@ class FluidInterface:
def audio_data_filter(self, *args, **kw):
return self.compose(filters.audio_data_filter(*args, **kw))
-
+
def audio_tokenize(self, *args, **kw):
return self.compose(filters.audio_tokenize(*args, **kw))
def resample(self, *args, **kw):
- return self.compose(filters.resample(*args, **kw))
+ return self.compose(filters.resample(*args, **kw))
def audio_compute_fbank(self, *args, **kw):
return self.compose(filters.audio_compute_fbank(*args, **kw))
@@ -102,27 +118,28 @@ class FluidInterface:
def audio_cmvn(self, cmvn_file):
return self.compose(filters.audio_cmvn(cmvn_file))
+
class WebDataset(DataPipeline, FluidInterface):
"""Small fluid-interface wrapper for DataPipeline."""
def __init__(
- self,
- urls,
- handler=reraise_exception,
- resampled=False,
- repeat=False,
- shardshuffle=None,
- cache_size=0,
- cache_dir=None,
- detshuffle=False,
- nodesplitter=shardlists.single_node_only,
- verbose=False,
- ):
+ self,
+ urls,
+ handler=reraise_exception,
+ resampled=False,
+ repeat=False,
+ shardshuffle=None,
+ cache_size=0,
+ cache_dir=None,
+ detshuffle=False,
+ nodesplitter=shardlists.single_node_only,
+ verbose=False, ):
super().__init__()
if isinstance(urls, IterableDataset):
assert not resampled
self.append(urls)
- elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")):
+ elif isinstance(urls, str) and (urls.endswith(".yaml") or
+ urls.endswith(".yml")):
with (open(urls)) as stream:
spec = yaml.safe_load(stream)
assert "datasets" in spec
@@ -152,9 +169,7 @@ class WebDataset(DataPipeline, FluidInterface):
handler=handler,
verbose=verbose,
cache_size=cache_size,
- cache_dir=cache_dir,
- )
- )
+ cache_dir=cache_dir, ))
class FluidWrapper(DataPipeline, FluidInterface):
diff --git a/paddlespeech/audio/streamdata/extradatasets.py b/paddlespeech/audio/streamdata/extradatasets.py
index e6d617724..0933908a3 100644
--- a/paddlespeech/audio/streamdata/extradatasets.py
+++ b/paddlespeech/audio/streamdata/extradatasets.py
@@ -5,13 +5,10 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
-
-
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
-
import itertools as itt
import os
import random
@@ -63,8 +60,7 @@ class repeatedly(IterableDataset, PipelineStage):
return utils.repeatedly(
source,
nepochs=self.nepochs,
- nbatches=self.nbatches,
- )
+ nbatches=self.nbatches, )
class with_epoch(IterableDataset):
diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py
index 82b9c6bab..4056203fb 100644
--- a/paddlespeech/audio/streamdata/filters.py
+++ b/paddlespeech/audio/streamdata/filters.py
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
-
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
"""A collection of iterators for data transformations.
@@ -12,28 +11,32 @@ These functions are plain iterator functions. You can find curried versions
in webdataset.filters, and you can find IterableDataset wrappers in
webdataset.processing.
"""
-
import io
-from fnmatch import fnmatch
+import itertools
+import os
+import random
import re
-import itertools, os, random, sys, time
-from functools import reduce, wraps
+import sys
+import time
+from fnmatch import fnmatch
+from functools import reduce
+from functools import wraps
import numpy as np
+import paddle
from . import autodecode
-from . import utils
-from .paddle_utils import PaddleTensor
-from .utils import PipelineStage
-
+from . import utils
from .. import backends
from ..compliance import kaldi
-import paddle
from ..transform.cmvn import GlobalCMVN
-from ..utils.tensor_utils import pad_sequence
-from ..transform.spec_augment import time_warp
-from ..transform.spec_augment import time_mask
from ..transform.spec_augment import freq_mask
+from ..transform.spec_augment import time_mask
+from ..transform.spec_augment import time_warp
+from ..utils.tensor_utils import pad_sequence
+from .paddle_utils import PaddleTensor
+from .utils import PipelineStage
+
class FilterFunction(object):
"""Helper class for currying pipeline stages.
@@ -159,10 +162,12 @@ def transform_with(sample, transformers):
result[i] = f(sample[i])
return result
+
###
# Iterators
###
+
def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""):
"""Print information about the samples that are passing through.
@@ -325,15 +330,24 @@ def _rename(data, handler=reraise_exception, keep=True, **kw):
for sample in data:
try:
if not keep:
- yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}
+ yield {
+ k: getfirst(sample, v, missing_is_error=True)
+ for k, v in kw.items()
+ }
else:
def listify(v):
return v.split(";") if isinstance(v, str) else v
to_be_replaced = {x for v in kw.values() for x in listify(v)}
- result = {k: v for k, v in sample.items() if k not in to_be_replaced}
- result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()})
+ result = {
+ k: v
+ for k, v in sample.items() if k not in to_be_replaced
+ }
+ result.update({
+ k: getfirst(sample, v, missing_is_error=True)
+ for k, v in kw.items()
+ })
yield result
except Exception as exn:
if handler(exn):
@@ -381,7 +395,11 @@ def _map_dict(data, handler=reraise_exception, **kw):
map_dict = pipelinefilter(_map_dict)
-def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None):
+def _to_tuple(data,
+ *args,
+ handler=reraise_exception,
+ missing_is_error=True,
+ none_is_error=None):
"""Convert dict samples to tuples."""
if none_is_error is None:
none_is_error = missing_is_error
@@ -390,7 +408,10 @@ def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, non
for sample in data:
try:
- result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args])
+ result = tuple([
+ getfirst(sample, f, missing_is_error=missing_is_error)
+ for f in args
+ ])
if none_is_error and any(x is None for x in result):
raise ValueError(f"to_tuple {args} got {sample.keys()}")
yield result
@@ -463,19 +484,28 @@ rsample = pipelinefilter(_rsample)
slice = pipelinefilter(itertools.islice)
-def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False):
+def _extract_keys(source,
+ *patterns,
+ duplicate_is_error=True,
+ ignore_missing=False):
for sample in source:
result = []
for pattern in patterns:
- pattern = pattern.split(";") if isinstance(pattern, str) else pattern
- matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)]
+ pattern = pattern.split(";") if isinstance(pattern,
+ str) else pattern
+ matches = [
+ x for x in sample.keys()
+ if any(fnmatch("." + x, p) for p in pattern)
+ ]
if len(matches) == 0:
if ignore_missing:
continue
else:
- raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.")
+ raise ValueError(
+ f"Cannot find {pattern} in sample keys {sample.keys()}.")
if len(matches) > 1 and duplicate_is_error:
- raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.")
+ raise ValueError(
+ f"Multiple sample keys {sample.keys()} match {pattern}.")
value = sample[matches[0]]
result.append(value)
yield tuple(result)
@@ -484,7 +514,12 @@ def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=Fal
extract_keys = pipelinefilter(_extract_keys)
-def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw):
+def _rename_keys(source,
+ *args,
+ keep_unselected=False,
+ must_match=True,
+ duplicate_is_error=True,
+ **kw):
renamings = [(pattern, output) for output, pattern in args]
renamings += [(pattern, output) for output, pattern in kw.items()]
for sample in source:
@@ -504,11 +539,15 @@ def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicat
continue
if new_name in new_sample:
if duplicate_is_error:
- raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.")
+ raise ValueError(
+ f"Duplicate value in sample {sample.keys()} after rename."
+ )
continue
new_sample[new_name] = value
if must_match and not all(matched.values()):
- raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).")
+ raise ValueError(
+ f"Not all patterns ({matched}) matched sample keys ({sample.keys()})."
+ )
yield new_sample
@@ -541,18 +580,18 @@ def find_decoder(decoders, path):
if fname.startswith("__"):
return lambda x: x
for pattern, fun in decoders[::-1]:
- if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern):
+ if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(),
+ pattern):
return fun
return None
def _xdecode(
- source,
- *args,
- must_decode=True,
- defaults=default_decoders,
- **kw,
-):
+ source,
+ *args,
+ must_decode=True,
+ defaults=default_decoders,
+ **kw, ):
decoders = list(defaults) + list(args)
decoders += [("*." + k, v) for k, v in kw.items()]
for sample in source:
@@ -575,18 +614,18 @@ def _xdecode(
new_sample[path] = value
yield new_sample
-xdecode = pipelinefilter(_xdecode)
+xdecode = pipelinefilter(_xdecode)
def _audio_data_filter(source,
- frame_shift=10,
- max_length=10240,
- min_length=10,
- token_max_length=200,
- token_min_length=1,
- min_output_input_ratio=0.0005,
- max_output_input_ratio=1):
+ frame_shift=10,
+ max_length=10240,
+ min_length=10,
+ token_max_length=200,
+ token_min_length=1,
+ min_output_input_ratio=0.0005,
+ max_output_input_ratio=1):
""" Filter sample according to feature and label length
Inplace operation.
@@ -613,7 +652,8 @@ def _audio_data_filter(source,
assert 'wav' in sample
assert 'label' in sample
# sample['wav'] is paddle.Tensor, we have 100 frames every second (default)
- num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift)
+ num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (
+ 1000 / frame_shift)
if num_frames < min_length:
continue
if num_frames > max_length:
@@ -629,13 +669,15 @@ def _audio_data_filter(source,
continue
yield sample
+
audio_data_filter = pipelinefilter(_audio_data_filter)
+
def _audio_tokenize(source,
- symbol_table,
- bpe_model=None,
- non_lang_syms=None,
- split_with_space=False):
+ symbol_table,
+ bpe_model=None,
+ non_lang_syms=None,
+ split_with_space=False):
""" Decode text to chars or BPE
Inplace operation
@@ -693,8 +735,10 @@ def _audio_tokenize(source,
sample['label'] = label
yield sample
+
audio_tokenize = pipelinefilter(_audio_tokenize)
+
def _audio_resample(source, resample_rate=16000):
""" Resample data.
Inplace operation.
@@ -713,18 +757,22 @@ def _audio_resample(source, resample_rate=16000):
waveform = sample['wav']
if sample_rate != resample_rate:
sample['sample_rate'] = resample_rate
- sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample(
- waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate
- ))
+ sample['wav'] = paddle.to_tensor(
+ backends.soundfile_backend.resample(
+ waveform.numpy(),
+ src_sr=sample_rate,
+ target_sr=resample_rate))
yield sample
+
audio_resample = pipelinefilter(_audio_resample)
+
def _audio_compute_fbank(source,
- num_mel_bins=80,
- frame_length=25,
- frame_shift=10,
- dither=0.0):
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0):
""" Extract fbank
Args:
@@ -746,30 +794,33 @@ def _audio_compute_fbank(source,
waveform = sample['wav']
waveform = waveform * (1 << 15)
# Only keep fname, feat, label
- mat = kaldi.fbank(waveform,
- n_mels=num_mel_bins,
- frame_length=frame_length,
- frame_shift=frame_shift,
- dither=dither,
- energy_floor=0.0,
- sr=sample_rate)
+ mat = kaldi.fbank(
+ waveform,
+ n_mels=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ sr=sample_rate)
yield dict(fname=sample['fname'], label=sample['label'], feat=mat)
audio_compute_fbank = pipelinefilter(_audio_compute_fbank)
-def _audio_spec_aug(source,
- max_w=5,
- w_inplace=True,
- w_mode="PIL",
- max_f=30,
- num_f_mask=2,
- f_inplace=True,
- f_replace_with_zero=False,
- max_t=40,
- num_t_mask=2,
- t_inplace=True,
- t_replace_with_zero=False,):
+
+def _audio_spec_aug(
+ source,
+ max_w=5,
+ w_inplace=True,
+ w_mode="PIL",
+ max_f=30,
+ num_f_mask=2,
+ f_inplace=True,
+ f_replace_with_zero=False,
+ max_t=40,
+ num_t_mask=2,
+ t_inplace=True,
+ t_replace_with_zero=False, ):
""" Do spec augmentation
Inplace operation
@@ -793,12 +844,23 @@ def _audio_spec_aug(source,
for sample in source:
x = sample['feat']
x = x.numpy()
- x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode)
- x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero)
- x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero)
+ x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode)
+ x = freq_mask(
+ x,
+ F=max_f,
+ n_mask=num_f_mask,
+ inplace=f_inplace,
+ replace_with_zero=f_replace_with_zero)
+ x = time_mask(
+ x,
+ T=max_t,
+ n_mask=num_t_mask,
+ inplace=t_inplace,
+ replace_with_zero=t_replace_with_zero)
sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32)
yield sample
+
audio_spec_aug = pipelinefilter(_audio_spec_aug)
@@ -829,8 +891,10 @@ def _sort(source, sort_size=500):
for x in buf:
yield x
+
sort = pipelinefilter(_sort)
+
def _batched(source, batch_size=16):
""" Static batch the data by `batch_size`
@@ -850,8 +914,10 @@ def _batched(source, batch_size=16):
if len(buf) > 0:
yield buf
+
batched = pipelinefilter(_batched)
+
def dynamic_batched(source, max_frames_in_batch=12000):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
@@ -892,8 +958,8 @@ def _audio_padding(source):
"""
for sample in source:
assert isinstance(sample, list)
- feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample],
- dtype="int64")
+ feats_length = paddle.to_tensor(
+ [x['feat'].shape[0] for x in sample], dtype="int64")
order = paddle.argsort(feats_length, descending=True)
feats_lengths = paddle.to_tensor(
[sample[i]['feat'].shape[0] for i in order], dtype="int64")
@@ -902,20 +968,20 @@ def _audio_padding(source):
sorted_labels = [
paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order
]
- label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels],
- dtype="int64")
- padded_feats = pad_sequence(sorted_feats,
- batch_first=True,
- padding_value=0)
- padding_labels = pad_sequence(sorted_labels,
- batch_first=True,
- padding_value=-1)
-
- yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
+ label_lengths = paddle.to_tensor(
+ [x.shape[0] for x in sorted_labels], dtype="int64")
+ padded_feats = pad_sequence(
+ sorted_feats, batch_first=True, padding_value=0)
+ padding_labels = pad_sequence(
+ sorted_labels, batch_first=True, padding_value=-1)
+
+ yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
label_lengths)
+
audio_padding = pipelinefilter(_audio_padding)
+
def _audio_cmvn(source, cmvn_file):
global_cmvn = GlobalCMVN(cmvn_file)
for batch in source:
@@ -923,13 +989,16 @@ def _audio_cmvn(source, cmvn_file):
padded_feats = padded_feats.numpy()
padded_feats = global_cmvn(padded_feats)
padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32)
- yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
- label_lengths)
+ yield (sorted_keys, padded_feats, feats_lengths, padding_labels,
+ label_lengths)
+
audio_cmvn = pipelinefilter(_audio_cmvn)
+
def _placeholder(source):
for data in source:
yield data
+
placeholder = pipelinefilter(_placeholder)
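
Every `xxx = pipelinefilter(_xxx)` pair in this module follows the currying
pattern documented on FilterFunction: a plain iterator function becomes a
stage factory whose result can be dropped into a DataPipeline. A
self-contained sketch of the mechanics, under that assumption:

    from paddlespeech.audio.streamdata.filters import pipelinefilter

    def _double(source):
        # a plain iterator function, shaped like _batched or _sort above
        for sample in source:
            yield sample * 2

    double = pipelinefilter(_double)
    stage = double()  # curried: now a callable pipeline stage
    assert list(stage(range(3))) == [0, 2, 4]
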
diff --git a/paddlespeech/audio/streamdata/gopen.py b/paddlespeech/audio/streamdata/gopen.py
index 457d048a6..60a434603 100644
--- a/paddlespeech/audio/streamdata/gopen.py
+++ b/paddlespeech/audio/streamdata/gopen.py
@@ -3,12 +3,12 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
-
-
"""Open URLs by calling subcommands."""
-
-import os, sys, re
-from subprocess import PIPE, Popen
+import os
+import re
+import sys
+from subprocess import PIPE
+from subprocess import Popen
from urllib.parse import urlparse
# global used for printing additional node information during verbose output
@@ -31,14 +31,13 @@ class Pipe:
"""
def __init__(
- self,
- *args,
- mode=None,
- timeout=7200.0,
- ignore_errors=False,
- ignore_status=[],
- **kw,
- ):
+ self,
+ *args,
+ mode=None,
+ timeout=7200.0,
+ ignore_errors=False,
+ ignore_status=[],
+ **kw, ):
"""Create an IO Pipe."""
self.ignore_errors = ignore_errors
self.ignore_status = [0] + ignore_status
@@ -75,8 +74,7 @@ class Pipe:
if verbose:
print(
f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}",
- file=sys.stderr,
- )
+ file=sys.stderr, )
if self.status not in self.ignore_status and not self.ignore_errors:
raise Exception(f"{self.args}: exit {self.status} (read) {info}")
@@ -114,9 +112,11 @@ class Pipe:
self.close()
-def set_options(
- obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None
-):
+def set_options(obj,
+ timeout=None,
+ ignore_errors=None,
+ ignore_status=None,
+ handler=None):
"""Set options for Pipes.
This function can be called on any stream. It will set pipe options only
@@ -168,16 +168,14 @@ def gopen_pipe(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141],
- ) # skipcq: BAN-B604
+ ignore_status=[141], ) # skipcq: BAN-B604
elif mode[0] == "w":
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141],
- ) # skipcq: BAN-B604
+ ignore_status=[141], ) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
@@ -196,8 +194,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141, 23],
- ) # skipcq: BAN-B604
+ ignore_status=[141, 23], ) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"curl -s -L -T - '{url}'"
return Pipe(
@@ -205,8 +202,7 @@ def gopen_curl(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141, 26],
- ) # skipcq: BAN-B604
+ ignore_status=[141, 26], ) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
@@ -226,15 +222,13 @@ def gopen_htgs(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141, 23],
- ) # skipcq: BAN-B604
+ ignore_status=[141, 23], ) # skipcq: BAN-B604
elif mode[0] == "w":
raise ValueError(f"{mode}: cannot write")
else:
raise ValueError(f"{mode}: unknown mode")
-
def gopen_gsutil(url, mode="rb", bufsize=8192):
"""Open a URL with `curl`.
@@ -249,8 +243,7 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141, 23],
- ) # skipcq: BAN-B604
+ ignore_status=[141, 23], ) # skipcq: BAN-B604
elif mode[0] == "w":
cmd = f"gsutil cp - '{url}'"
return Pipe(
@@ -258,13 +251,11 @@ def gopen_gsutil(url, mode="rb", bufsize=8192):
mode=mode,
shell=True,
bufsize=bufsize,
- ignore_status=[141, 26],
- ) # skipcq: BAN-B604
+ ignore_status=[141, 26], ) # skipcq: BAN-B604
else:
raise ValueError(f"{mode}: unknown mode")
-
def gopen_error(url, *args, **kw):
"""Raise a value error.
@@ -285,8 +276,7 @@ gopen_schemes = dict(
ftps=gopen_curl,
scp=gopen_curl,
gs=gopen_gsutil,
- htgs=gopen_htgs,
-)
+ htgs=gopen_htgs, )
def gopen(url, mode="rb", bufsize=8192, **kw):
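
A short usage sketch for the dispatch table above (hypothetical URL; the
https scheme is routed to gopen_curl, so curl must be on PATH, and the
context-manager use assumes Pipe's close-on-exit behavior):

    from paddlespeech.audio.streamdata.gopen import gopen

    with gopen("https://example.com/shard-000.tar", "rb") as stream:
        head = stream.read(512)
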
diff --git a/paddlespeech/audio/streamdata/handlers.py b/paddlespeech/audio/streamdata/handlers.py
index 7f3d28b62..0173e5373 100644
--- a/paddlespeech/audio/streamdata/handlers.py
+++ b/paddlespeech/audio/streamdata/handlers.py
@@ -3,7 +3,6 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
-
"""Pluggable exception handlers.
These are functions that take an exception as an argument and then return...
@@ -14,8 +13,8 @@ These are functions that take an exception as an argument and then return...
They are used as handler= arguments in much of the library.
"""
-
-import time, warnings
+import time
+import warnings
def reraise_exception(exn):
diff --git a/paddlespeech/audio/streamdata/mix.py b/paddlespeech/audio/streamdata/mix.py
index 7d790f00f..8ab01a1c8 100644
--- a/paddlespeech/audio/streamdata/mix.py
+++ b/paddlespeech/audio/streamdata/mix.py
@@ -5,16 +5,21 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
-
"""Classes for mixing samples from multiple sources."""
-
-import itertools, os, random, time, sys
-from functools import reduce, wraps
+import itertools
+import os
+import random
+import sys
+import time
+from functools import reduce
+from functools import wraps
import numpy as np
-from . import autodecode, utils
-from .paddle_utils import PaddleTensor, IterableDataset
+from . import autodecode
+from . import utils
+from .paddle_utils import IterableDataset
+from .paddle_utils import PaddleTensor
from .utils import PipelineStage
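
A minimal sketch of the two mixing strategies defined in this module; the
lists stand in for real IterableDataset pipelines, and the constructor
signatures are assumed from upstream webdataset:

    from paddlespeech.audio.streamdata.mix import RandomMix, RoundRobin

    clean_ds, noisy_ds = [0, 1, 2], ["a", "b"]  # stand-ins for datasets
    mixed = RandomMix([clean_ds, noisy_ds], probs=[0.8, 0.2])
    alternating = RoundRobin([clean_ds, noisy_ds])
    print(list(alternating))  # interleaves samples from both sources
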
diff --git a/paddlespeech/audio/streamdata/paddle_utils.py b/paddlespeech/audio/streamdata/paddle_utils.py
index 02bc4c841..2368ac675 100644
--- a/paddlespeech/audio/streamdata/paddle_utils.py
+++ b/paddlespeech/audio/streamdata/paddle_utils.py
@@ -5,10 +5,8 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
-
"""Mock implementations of paddle interfaces when paddle is not available."""
-
try:
from paddle.io import DataLoader, IterableDataset
except ModuleNotFoundError:
@@ -23,6 +21,7 @@ except ModuleNotFoundError:
pass
+
try:
from paddle import Tensor as PaddleTensor
except ModuleNotFoundError:
diff --git a/paddlespeech/audio/streamdata/pipeline.py b/paddlespeech/audio/streamdata/pipeline.py
index 7339a762a..53e5dc246 100644
--- a/paddlespeech/audio/streamdata/pipeline.py
+++ b/paddlespeech/audio/streamdata/pipeline.py
@@ -3,15 +3,21 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#%%
-import copy, os, random, sys, time
+import copy
+import os
+import random
+import sys
+import time
from dataclasses import dataclass
from itertools import islice
from typing import List
-import braceexpand, yaml
+import braceexpand
+import yaml
from .handlers import reraise_exception
-from .paddle_utils import DataLoader, IterableDataset
+from .paddle_utils import DataLoader
+from .paddle_utils import IterableDataset
from .utils import PipelineStage
@@ -22,8 +28,7 @@ def add_length_method(obj):
Combined = type(
obj.__class__.__name__ + "_Length",
(obj.__class__, IterableDataset),
- {"__len__": length},
- )
+ {"__len__": length}, )
obj.__class__ = Combined
return obj
diff --git a/paddlespeech/audio/streamdata/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py
index cfaf9a64b..5b6c64351 100644
--- a/paddlespeech/audio/streamdata/shardlists.py
+++ b/paddlespeech/audio/streamdata/shardlists.py
@@ -4,28 +4,30 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
-
# Modified from https://github.com/webdataset/webdataset
-
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
-
-import os, random, sys, time
-from dataclasses import dataclass, field
+import os
+import random
+import sys
+import time
+from dataclasses import dataclass
+from dataclasses import field
from itertools import islice
from typing import List
-import braceexpand, yaml
+import braceexpand
+import yaml
from . import utils
+from ..utils.log import Logger
from .filters import pipelinefilter
from .paddle_utils import IterableDataset
+logger = Logger(__name__)
-from ..utils.log import Logger
-logger = Logger(__name__)
def expand_urls(urls):
if isinstance(urls, str):
urllist = urls.split("::")
@@ -64,7 +66,8 @@ class SimpleShardList(IterableDataset):
def split_by_node(src, group=None):
- rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
+ rank, world_size, worker, num_workers = utils.paddle_worker_info(
+ group=group)
logger.info(f"world_size:{world_size}, rank:{rank}")
if world_size > 1:
for s in islice(src, rank, None, world_size):
@@ -75,9 +78,11 @@ def split_by_node(src, group=None):
def single_node_only(src, group=None):
- rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
+ rank, world_size, worker, num_workers = utils.paddle_worker_info(
+ group=group)
if world_size > 1:
- raise ValueError("input pipeline needs to be reconfigured for multinode training")
+ raise ValueError(
+ "input pipeline needs to be reconfigured for multinode training")
for s in src:
yield s
@@ -104,7 +109,8 @@ def resampled_(src, n=sys.maxsize):
rng = random.Random(seed)
print("# resampled loading", file=sys.stderr)
items = list(src)
- print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
+ print(
+ f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr)
for i in range(n):
yield rng.choice(items)
@@ -118,7 +124,9 @@ def non_empty(src):
yield s
count += 1
if count == 0:
- raise ValueError("pipeline stage received no data at all and this was declared as an error")
+ raise ValueError(
+ "pipeline stage received no data at all and this was declared as an error"
+ )
@dataclass
@@ -142,6 +150,8 @@ class MultiShardSample(IterableDataset):
def __init__(self, fname):
"""Construct a shardlist from multiple sources using a YAML spec."""
self.epoch = -1
+
+
class MultiShardSample(IterableDataset):
def __init__(self, fname):
"""Construct a shardlist from multiple sources using a YAML spec."""
@@ -156,20 +166,23 @@ class MultiShardSample(IterableDataset):
else:
with open(fname) as stream:
spec = yaml.safe_load(stream)
- assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys())
+ assert set(spec.keys()).issubset(
+ set("prefix datasets buckets".split())), list(spec.keys())
prefix = expand(spec.get("prefix", ""))
self.sources = []
for ds in spec["datasets"]:
- assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list(
- ds.keys()
- )
+ assert set(ds.keys()).issubset(
+ set("buckets name shards resample choose".split())), list(
+ ds.keys())
buckets = ds.get("buckets", spec.get("buckets", []))
if isinstance(buckets, str):
buckets = [buckets]
buckets = [expand(s) for s in buckets]
if buckets == []:
buckets = [""]
- assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
+ assert len(
+ buckets
+ ) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented"
bucket = buckets[0]
name = ds.get("name", "@" + bucket)
urls = ds["shards"]
@@ -177,15 +190,19 @@ class MultiShardSample(IterableDataset):
urls = [urls]
# urls = [u for url in urls for u in braceexpand.braceexpand(url)]
urls = [
- prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url))
+ prefix + os.path.join(bucket, u)
+ for url in urls for u in braceexpand.braceexpand(expand(url))
]
resample = ds.get("resample", -1)
nsample = ds.get("choose", -1)
if nsample > len(urls):
- raise ValueError(f"perepoch {nsample} must be no greater than the number of shards")
+ raise ValueError(
+ f"perepoch {nsample} must be no greater than the number of shards"
+ )
if (nsample > 0) and (resample > 0):
raise ValueError("specify only one of perepoch or choose")
- entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample)
+ entry = MSSource(
+ name=name, urls=urls, perepoch=nsample, resample=resample)
self.sources.append(entry)
print(f"# {name} {len(urls)} {nsample}", file=sys.stderr)
@@ -203,7 +220,7 @@ class MultiShardSample(IterableDataset):
# sample without replacement
l = list(source.urls)
self.rng.shuffle(l)
- l = l[: source.perepoch]
+ l = l[:source.perepoch]
else:
l = list(source.urls)
result += l
@@ -227,12 +244,11 @@ class ResampledShards(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
- self,
- urls,
- nshards=sys.maxsize,
- worker_seed=None,
- deterministic=False,
- ):
+ self,
+ urls,
+ nshards=sys.maxsize,
+ worker_seed=None,
+ deterministic=False, ):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
@@ -252,7 +268,8 @@ class ResampledShards(IterableDataset):
if self.deterministic:
seed = utils.make_seed(self.worker_seed(), self.epoch)
else:
- seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4))
+ seed = utils.make_seed(self.worker_seed(), self.epoch,
+ os.getpid(), time.time_ns(), os.urandom(4))
if os.environ.get("WDS_SHOW_SEED", "0") == "1":
print(f"# ResampledShards seed {seed}")
self.rng = random.Random(seed)
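
The assertions in MultiShardSample above pin down a small YAML schema
(top-level prefix/datasets/buckets; per-dataset name/shards/choose/resample,
with choose and resample mutually exclusive). A hypothetical spec that
satisfies it:

    spec = """
    prefix: https://example.com/shards/
    datasets:
      - name: train-clean
        shards: clean-{000..099}.tar
        choose: 50     # 50 shards per epoch, sampled without replacement
      - name: train-other
        shards: other-{000..049}.tar
        resample: 25   # 25 shards per epoch, sampled with replacement
    """
    with open("shards.yml", "w") as fh:
        fh.write(spec)
    # MultiShardSample("shards.yml") then yields one dict(url=...) per shard
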
diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py
index b1616918c..79b81c0ce 100644
--- a/paddlespeech/audio/streamdata/tariterators.py
+++ b/paddlespeech/audio/streamdata/tariterators.py
@@ -3,13 +3,12 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
-
# Modified from https://github.com/webdataset/webdataset
# Modified from wenet(https://github.com/wenet-e2e/wenet)
-
"""Low level iteration functions for tar archives."""
-
-import random, re, tarfile
+import random
+import re
+import tarfile
import braceexpand
@@ -27,6 +26,7 @@ import numpy as np
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+
def base_plus_ext(path):
"""Split off all file extensions.
@@ -47,12 +47,8 @@ def valid_sample(sample):
:param sample: sample to be checked
"""
- return (
- sample is not None
- and isinstance(sample, dict)
- and len(list(sample.keys())) > 0
- and not sample.get("__bad__", False)
- )
+ return (sample is not None and isinstance(sample, dict) and
+ len(list(sample.keys())) > 0 and not sample.get("__bad__", False))
# FIXME: UNUSED
@@ -79,16 +75,16 @@ def url_opener(data, handler=reraise_exception, **kw):
sample.update(stream=stream)
yield sample
except Exception as exn:
- exn.args = exn.args + (url,)
+ exn.args = exn.args + (url, )
if handler(exn):
continue
else:
break
-def tar_file_iterator(
- fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
-):
+def tar_file_iterator(fileobj,
+ skip_meta=r"__[^/]*__($|/)",
+ handler=reraise_exception):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
@@ -103,11 +99,8 @@ def tar_file_iterator(
continue
if fname is None:
continue
- if (
- "/" not in fname
- and fname.startswith(meta_prefix)
- and fname.endswith(meta_suffix)
- ):
+ if ("/" not in fname and fname.startswith(meta_prefix) and
+ fname.endswith(meta_suffix)):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
@@ -118,8 +111,10 @@ def tar_file_iterator(
assert pos > 0
prefix, postfix = name[:pos], name[pos + 1:]
if postfix == 'wav':
- waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False)
- result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate)
+ waveform, sample_rate = paddlespeech.audio.load(
+ stream.extractfile(tarinfo), normal=False)
+ result = dict(
+ fname=prefix, wav=waveform, sample_rate=sample_rate)
else:
txt = stream.extractfile(tarinfo).read().decode('utf8').strip()
result = dict(fname=prefix, txt=txt)
@@ -128,16 +123,17 @@ def tar_file_iterator(
stream.members = []
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
- exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
+ exn.args = (exn.args[0] + " @ " + str(fileobj), ) + exn.args[1:]
if handler(exn):
continue
else:
break
del stream
-def tar_file_and_group_iterator(
- fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception
-):
+
+def tar_file_and_group_iterator(fileobj,
+ skip_meta=r"__[^/]*__($|/)",
+ handler=reraise_exception):
""" Expand a stream of open tar files into a stream of tar file contents.
And groups the file with same prefix
@@ -167,8 +163,11 @@ def tar_file_and_group_iterator(
if postfix == 'txt':
example['txt'] = file_obj.read().decode('utf8').strip()
elif postfix in AUDIO_FORMAT_SETS:
- waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False)
- waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32)
+ waveform, sample_rate = paddlespeech.audio.load(
+ file_obj, normal=False)
+ waveform = paddle.to_tensor(
+ np.expand_dims(np.array(waveform), 0),
+ dtype=paddle.float32)
example['wav'] = waveform
example['sample_rate'] = sample_rate
@@ -176,19 +175,21 @@ def tar_file_and_group_iterator(
example[postfix] = file_obj.read()
except Exception as exn:
if hasattr(exn, "args") and len(exn.args) > 0:
- exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:]
+ exn.args = (exn.args[0] + " @ " + str(fileobj),
+ ) + exn.args[1:]
if handler(exn):
continue
else:
break
valid = False
- # logging.warning('error to parse {}'.format(name))
+                # logging.warning('failed to parse {}'.format(name))
prev_prefix = prefix
if prev_prefix is not None:
example['fname'] = prev_prefix
yield example
stream.close()
+
def tar_file_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
@@ -200,9 +201,8 @@ def tar_file_expander(data, handler=reraise_exception):
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_iterator(source["stream"]):
- assert (
- isinstance(sample, dict) and "data" in sample and "fname" in sample
- )
+ assert (isinstance(sample, dict) and "data" in sample and
+ "fname" in sample)
sample["__url__"] = url
yield sample
except Exception as exn:
@@ -213,8 +213,6 @@ def tar_file_expander(data, handler=reraise_exception):
break
-
-
def tar_file_and_group_expander(data, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
@@ -226,9 +224,8 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
assert isinstance(source, dict)
assert "stream" in source
for sample in tar_file_and_group_iterator(source["stream"]):
- assert (
- isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample
- )
+ assert (isinstance(sample, dict) and "wav" in sample and
+ "txt" in sample and "fname" in sample)
sample["__url__"] = url
yield sample
except Exception as exn:
@@ -239,7 +236,11 @@ def tar_file_and_group_expander(data, handler=reraise_exception):
break
-def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+def group_by_keys(data,
+ keys=base_plus_ext,
+ lcase=True,
+ suffixes=None,
+ handler=None):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
@@ -254,8 +255,8 @@ def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=N
print(
prefix,
suffix,
- current_sample.keys() if isinstance(current_sample, dict) else None,
- )
+ current_sample.keys()
+ if isinstance(current_sample, dict) else None, )
if prefix is None:
continue
if lcase:
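
A small sketch of the grouped iteration this file implements, assuming a
hypothetical local shard containing paired utt0.wav/utt0.txt entries:

    from paddlespeech.audio.streamdata.tariterators import (
        tar_file_and_group_iterator)

    with open("data/train-000.tar", "rb") as fh:
        for sample in tar_file_and_group_iterator(fh):
            # one dict per prefix: fname, wav (paddle.Tensor), sample_rate, txt
            print(sample["fname"], sample["sample_rate"], sample["txt"])
            break
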
diff --git a/paddlespeech/audio/streamdata/utils.py b/paddlespeech/audio/streamdata/utils.py
index c7294f2bf..74eea6a03 100644
--- a/paddlespeech/audio/streamdata/utils.py
+++ b/paddlespeech/audio/streamdata/utils.py
@@ -4,22 +4,24 @@
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
-
# Modified from https://github.com/webdataset/webdataset
-
"""Miscellaneous utility functions."""
-
import importlib
import itertools as itt
import os
import re
import sys
-from typing import Any, Callable, Iterator, Optional, Union
+from typing import Any
+from typing import Callable
+from typing import Iterator
+from typing import Optional
+from typing import Union
from ..utils.log import Logger
logger = Logger(__name__)
+
def make_seed(*args):
seed = 0
for arg in args:
@@ -37,7 +39,7 @@ def identity(x: Any) -> Any:
return x
-def safe_eval(s: str, expr: str = "{}"):
+def safe_eval(s: str, expr: str="{}"):
"""Evaluate the given expression more safely."""
if re.sub("[^A-Za-z0-9_]", "", s) != s:
raise ValueError(f"safe_eval: illegal characters in: '{s}'")
@@ -54,9 +56,9 @@ def lookup_sym(sym: str, modules: list):
return None
-def repeatedly0(
- loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize
-):
+def repeatedly0(loader: Iterator,
+ nepochs: int=sys.maxsize,
+ nbatches: int=sys.maxsize):
"""Repeatedly returns batches from a DataLoader."""
for epoch in range(nepochs):
for sample in itt.islice(loader, nbatches):
@@ -69,12 +71,11 @@ def guess_batchsize(batch: Union[tuple, list]):
def repeatedly(
- source: Iterator,
- nepochs: int = None,
- nbatches: int = None,
- nsamples: int = None,
- batchsize: Callable[..., int] = guess_batchsize,
-):
+ source: Iterator,
+ nepochs: int=None,
+ nbatches: int=None,
+ nsamples: int=None,
+ batchsize: Callable[..., int]=guess_batchsize, ):
"""Repeatedly yield samples from an iterator."""
epoch = 0
batch = 0
@@ -93,6 +94,7 @@ def repeatedly(
if nepochs is not None and epoch >= nepochs:
return
+
def paddle_worker_info(group=None):
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
@@ -126,6 +128,7 @@ def paddle_worker_info(group=None):
return rank, world_size, worker, num_workers
+
def paddle_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = paddle_worker_info(group=group)
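
paddle_worker_info is the primitive behind split_by_node/split_by_worker in
shardlists.py; outside a DataLoader worker and without distributed training
it falls back to a single-process view:

    from paddlespeech.audio.streamdata.utils import paddle_worker_info

    rank, world_size, worker, num_workers = paddle_worker_info()
    print(rank, world_size, worker, num_workers)  # typically 0 1 0 1
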
diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py
index 7d4f7703b..3928a3ba6 100644
--- a/paddlespeech/audio/streamdata/writer.py
+++ b/paddlespeech/audio/streamdata/writer.py
@@ -5,18 +5,24 @@
# See the LICENSE file for licensing terms (BSD-style).
# Modified from https://github.com/webdataset/webdataset
#
-
"""Classes and functions for writing tar files and WebDataset files."""
-
-import io, json, pickle, re, tarfile, time
-from typing import Any, Callable, Optional, Union
+import io
+import json
+import pickle
+import re
+import tarfile
+import time
+from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Union
import numpy as np
from . import gopen
-def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622
+def imageencoder(image: Any, format: str="PNG"): # skipcq: PYL-W0622
"""Compress an image using PIL and return it as a string.
Can handle float or uint8 images.
@@ -67,6 +73,7 @@ def bytestr(data: Any):
return data.encode("ascii")
return str(data).encode("ascii")
+
def paddle_dumps(data: Any):
"""Dump data into a bytestring using paddle.dumps.
@@ -82,6 +89,7 @@ def paddle_dumps(data: Any):
paddle.save(data, stream)
return stream.getvalue()
+
def numpy_dumps(data: np.ndarray):
"""Dump data into a bytestring using numpy npy format.
@@ -139,9 +147,8 @@ def add_handlers(d, keys, value):
def make_handlers():
"""Create a list of handlers for encoding data."""
handlers = {}
- add_handlers(
- handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii")
- )
+ add_handlers(handlers, "cls cls2 class count index inx id",
+ lambda x: str(x).encode("ascii"))
add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8"))
add_handlers(handlers, "html htm", lambda x: x.encode("utf-8"))
add_handlers(handlers, "pyd pickle", pickle.dumps)
@@ -152,7 +159,8 @@ def make_handlers():
add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8"))
add_handlers(handlers, "mp msgpack msg", mp_dumps)
add_handlers(handlers, "cbor", cbor_dumps)
- add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg"))
+ add_handlers(handlers, "jpg jpeg img image",
+ lambda data: imageencoder(data, "jpg"))
add_handlers(handlers, "png", lambda data: imageencoder(data, "png"))
add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm"))
add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm"))
@@ -192,7 +200,8 @@ def encode_based_on_extension(sample: dict, handlers: dict):
:param handlers: handlers for encoding
"""
return {
- k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items())
+ k: encode_based_on_extension1(v, k, handlers)
+ for k, v in list(sample.items())
}
@@ -258,15 +267,14 @@ class TarWriter:
"""
def __init__(
- self,
- fileobj,
- user: str = "bigdata",
- group: str = "bigdata",
- mode: int = 0o0444,
- compress: Optional[bool] = None,
- encoder: Union[None, bool, Callable] = True,
- keep_meta: bool = False,
- ):
+ self,
+ fileobj,
+ user: str="bigdata",
+ group: str="bigdata",
+ mode: int=0o0444,
+ compress: Optional[bool]=None,
+ encoder: Union[None, bool, Callable]=True,
+ keep_meta: bool=False, ):
"""Create a tar writer.
:param fileobj: stream to write data to
@@ -330,8 +338,7 @@ class TarWriter:
continue
if not isinstance(v, (bytes, bytearray, memoryview)):
raise ValueError(
- f"{k} doesn't map to a bytes after encoding ({type(v)})"
- )
+ f"{k} doesn't map to a bytes after encoding ({type(v)})")
key = obj["__key__"]
for k in sorted(obj.keys()):
if k == "__key__":
@@ -349,7 +356,8 @@ class TarWriter:
ti.uname = self.user
ti.gname = self.group
if not isinstance(v, (bytes, bytearray, memoryview)):
- raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}")
+ raise ValueError(
+ f"converter didn't yield bytes: {k}, {type(v)}")
stream = io.BytesIO(v)
self.tarstream.addfile(ti, stream)
total += ti.size
@@ -360,14 +368,13 @@ class ShardWriter:
"""Like TarWriter but splits into multiple shards."""
def __init__(
- self,
- pattern: str,
- maxcount: int = 100000,
- maxsize: float = 3e9,
- post: Optional[Callable] = None,
- start_shard: int = 0,
- **kw,
- ):
+ self,
+ pattern: str,
+ maxcount: int=100000,
+ maxsize: float=3e9,
+ post: Optional[Callable]=None,
+ start_shard: int=0,
+ **kw, ):
"""Create a ShardWriter.
:param pattern: output file pattern
@@ -400,8 +407,7 @@ class ShardWriter:
self.fname,
self.count,
"%.1f GB" % (self.size / 1e9),
- self.total,
- )
+ self.total, )
self.shard += 1
stream = open(self.fname, "wb")
self.tarstream = TarWriter(stream, **self.kw)
@@ -413,11 +419,8 @@ class ShardWriter:
:param obj: sample to be written
"""
- if (
- self.tarstream is None
- or self.count >= self.maxcount
- or self.size >= self.maxsize
- ):
+ if (self.tarstream is None or self.count >= self.maxcount or
+ self.size >= self.maxsize):
self.next_stream()
size = self.tarstream.write(obj)
self.count += 1
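
ShardWriter.write() above rotates shards lazily: before adding a sample it checks whether a tar stream is open and whether the current shard has reached its sample-count or byte-size ceiling, and only then calls next_stream(). A condensed sketch of that guard, reusing the defaults from the signature above:

    def needs_new_shard(tarstream, count: int, size: int,
                        maxcount: int=100000, maxsize: float=3e9) -> bool:
        # Rotate when no shard is open yet, or either ceiling is reached.
        return tarstream is None or count >= maxcount or size >= maxsize
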
diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py
index 91c4d75c3..bcd6df54b 100644
--- a/paddlespeech/audio/text/text_featurizer.py
+++ b/paddlespeech/audio/text/text_featurizer.py
@@ -17,6 +17,7 @@ from typing import Union
import sentencepiece as spm
+from ..utils.log import Logger
from .utility import BLANK
from .utility import EOS
from .utility import load_dict
@@ -24,7 +25,6 @@ from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
-from ..utils.log import Logger
logger = Logger(__name__)
diff --git a/paddlespeech/audio/transform/perturb.py b/paddlespeech/audio/transform/perturb.py
index 8044dc36f..0825caec8 100644
--- a/paddlespeech/audio/transform/perturb.py
+++ b/paddlespeech/audio/transform/perturb.py
@@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
+import io
+import os
+
+import h5py
import librosa
import numpy
+import numpy as np
import scipy
import soundfile
-import io
-import os
-import h5py
-import numpy as np
class SoundHDF5File():
"""Collecting sound files to a HDF5 file
@@ -109,6 +110,7 @@ class SoundHDF5File():
def close(self):
self.file.close()
+
class SpeedPerturbation():
"""SpeedPerturbation
@@ -558,4 +560,3 @@ class RIRConvolve():
[scipy.convolve(x, r, mode="same") for r in rir], axis=-1)
else:
return scipy.convolve(x, rir, mode="same")
-
diff --git a/paddlespeech/audio/transform/spec_augment.py b/paddlespeech/audio/transform/spec_augment.py
index 029e7b8f5..b2635066f 100644
--- a/paddlespeech/audio/transform/spec_augment.py
+++ b/paddlespeech/audio/transform/spec_augment.py
@@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
"""Spec Augment module for preprocessing i.e., data augmentation"""
import random
+
import numpy
from PIL import Image
diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py
index 3800c36db..b53eed88c 100644
--- a/paddlespeech/cli/executor.py
+++ b/paddlespeech/cli/executor.py
@@ -191,7 +191,7 @@ class BaseExecutor(ABC):
line = line.strip()
if not line:
continue
- k, v = line.split() # space or \t
+ k, v = line.split() # space or \t
job_contents[k] = v
return job_contents
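
The job file parsed above stores one "key value" record per line, split on arbitrary whitespace (space or tab). A sketch of the same parse, assuming exactly two fields per non-empty line:

    def parse_job(lines):
        # Each stripped, non-empty line contributes one key/value pair.
        return dict(line.split() for line in map(str.strip, lines) if line)
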
diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py
index f6476b9aa..5fe2e16b9 100644
--- a/paddlespeech/s2t/__init__.py
+++ b/paddlespeech/s2t/__init__.py
@@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
paddle.Tensor.new_full = new_full
paddle.static.Variable.new_full = new_full
+
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs
diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index cdad3b8f7..4978aab75 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -26,8 +26,8 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
@@ -109,7 +109,8 @@ class U2Trainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
- logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+ logger.info(
+ f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -136,7 +137,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
- msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+ msg += "batch: {}/{}, ".format(i + 1,
+ len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -157,7 +159,8 @@ class U2Trainer(Trainer):
self.before_train()
if not self.use_streamdata:
- logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+ logger.info(
+ f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -225,14 +228,18 @@ class U2Trainer(Trainer):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
- self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
- self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+ self.train_loader = DataLoaderFactory.get_dataloader(
+ 'train', config, self.args)
+ self.valid_loader = DataLoaderFactory.get_dataloader(
+ 'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
- self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
- self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+ self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+ self.args)
+ self.align_loader = DataLoaderFactory.get_dataloader(
+ 'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py
index cb015c116..073d74293 100644
--- a/paddlespeech/s2t/exps/u2_kaldi/model.py
+++ b/paddlespeech/s2t/exps/u2_kaldi/model.py
@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
- logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+ logger.info(
+ f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -133,7 +134,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
- msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+ msg += "batch: {}/{}, ".format(i + 1,
+ len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -153,7 +155,8 @@ class U2Trainer(Trainer):
self.before_train()
if not self.use_streamdata:
- logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+ logger.info(
+ f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -165,8 +168,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
- msg += "batch : {}/{}, ".format(batch_index + 1,
- len(self.train_loader))
+ msg += "batch : {}/{}, ".format(
+ batch_index + 1, len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
@@ -204,21 +207,24 @@ class U2Trainer(Trainer):
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
config = self.config.clone()
- self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
+ self.train_loader = DataLoaderFactory.get_dataloader(
+ 'train', config, self.args)
config = self.config.clone()
config['preprocess_config'] = None
- self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+ self.valid_loader = DataLoaderFactory.get_dataloader(
+ 'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
config = self.config.clone()
config['preprocess_config'] = None
- self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+ self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+ self.args)
config = self.config.clone()
config['preprocess_config'] = None
- self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+ self.align_loader = DataLoaderFactory.get_dataloader(
+ 'align', config, self.args)
logger.info("Setup test/align Dataloader!")
-
def setup_model(self):
config = self.config
diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py
index 603825435..d57c49546 100644
--- a/paddlespeech/s2t/exps/u2_st/model.py
+++ b/paddlespeech/s2t/exps/u2_st/model.py
@@ -121,7 +121,8 @@ class U2STTrainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
- logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+ logger.info(
+ f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -155,7 +156,8 @@ class U2STTrainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
- msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+ msg += "batch: {}/{}, ".format(i + 1,
+ len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -175,7 +177,8 @@ class U2STTrainer(Trainer):
self.before_train()
if not self.use_streamdata:
- logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+ logger.info(
+ f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -248,14 +251,16 @@ class U2STTrainer(Trainer):
config['load_transcript'] = load_transcript
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
- self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
- self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+ self.train_loader = DataLoaderFactory.get_dataloader(
+ 'train', config, self.args)
+ self.valid_loader = DataLoaderFactory.get_dataloader(
+ 'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
- self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
+ self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+ self.args)
logger.info("Setup test Dataloader!")
-
def setup_model(self):
config = self.config
model_conf = config
diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py
index 831830241..3aff5f59b 100644
--- a/paddlespeech/s2t/io/dataloader.py
+++ b/paddlespeech/s2t/io/dataloader.py
@@ -22,17 +22,16 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
+from yacs.config import CfgNode
+import paddlespeech.audio.streamdata as streamdata
+from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.batchfy import make_batchset
from paddlespeech.s2t.io.converter import CustomConverter
from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
-import paddlespeech.audio.streamdata as streamdata
-from paddlespeech.audio.text.text_featurizer import TextFeaturizer
-from yacs.config import CfgNode
-
__all__ = ["BatchDataLoader", "StreamDataLoader"]
logger = Log(__name__).getlog()
@@ -61,6 +60,7 @@ def batch_collate(x):
"""
return x[0]
+
def read_preprocess_cfg(preprocess_conf_file):
augment_conf = dict()
preprocess_cfg = CfgNode(new_allowed=True)
@@ -82,7 +82,8 @@ def read_preprocess_cfg(preprocess_conf_file):
augment_conf['num_t_mask'] = process['n_mask']
augment_conf['t_inplace'] = process['inplace']
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
- return augment_conf
+ return augment_conf
+
class StreamDataLoader():
def __init__(self,
@@ -95,12 +96,12 @@ class StreamDataLoader():
frame_length=25,
frame_shift=10,
dither=0.0,
- minlen_in: float=0.0,
+ minlen_in: float=0.0,
maxlen_in: float=float('inf'),
minlen_out: float=0.0,
maxlen_out: float=float('inf'),
resample_rate: int=16000,
- shuffle_size: int=10000,
+ shuffle_size: int=10000,
sort_size: int=1000,
n_iter_processes: int=1,
prefetch_factor: int=2,
@@ -116,11 +117,11 @@ class StreamDataLoader():
text_featurizer = TextFeaturizer(unit_type, vocab_filepath)
symbol_table = text_featurizer.vocab_dict
- self.feat_dim = num_mel_bins
- self.vocab_size = text_featurizer.vocab_size
-
+ self.feat_dim = num_mel_bins
+ self.vocab_size = text_featurizer.vocab_size
+
augment_conf = read_preprocess_cfg(preprocess_conf)
-
+
# The list of shard
shardlist = []
with open(manifest_file, "r") as f:
@@ -128,58 +129,68 @@ class StreamDataLoader():
shardlist.append(line.strip())
world_size = 1
try:
- world_size = paddle.distributed.get_world_size()
+ world_size = paddle.distributed.get_world_size()
except Exception as e:
             logger.warning(e)
- logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1")
- assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...")
+            logger.warning(
+                "cannot get world_size via paddle.distributed.get_world_size(); falling back to world_size=1"
+            )
+        assert len(shardlist) >= world_size, \
+            "the length of the shard list should be >= the number of gpus/xpus/..."
- update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
+ update_n_iter_processes = int(
+ max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
if update_n_iter_processes != self.n_iter_processes:
- self.n_iter_processes = update_n_iter_processes
+ self.n_iter_processes = update_n_iter_processes
logger.info(f"change nun_workers to {self.n_iter_processes}")
if self.dist_sampler:
base_dataset = streamdata.DataPipeline(
- streamdata.SimpleShardList(shardlist),
- streamdata.split_by_node if train_mode else streamdata.placeholder(),
+ streamdata.SimpleShardList(shardlist), streamdata.split_by_node
+ if train_mode else streamdata.placeholder(),
streamdata.split_by_worker,
- streamdata.tarfile_to_samples(streamdata.reraise_exception)
- )
+ streamdata.tarfile_to_samples(streamdata.reraise_exception))
else:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_worker,
- streamdata.tarfile_to_samples(streamdata.reraise_exception)
- )
+ streamdata.tarfile_to_samples(streamdata.reraise_exception))
self.dataset = base_dataset.append_list(
streamdata.audio_tokenize(symbol_table),
- streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
+ streamdata.audio_data_filter(
+ frame_shift=frame_shift,
+ max_length=maxlen_in,
+ min_length=minlen_in,
+ token_max_length=maxlen_out,
+ token_min_length=minlen_out),
streamdata.audio_resample(resample_rate=resample_rate),
- streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
- streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
+ streamdata.audio_compute_fbank(
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither),
+ streamdata.audio_spec_aug(**augment_conf)
+ if train_mode else streamdata.placeholder(
+ ), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.shuffle(shuffle_size),
streamdata.sort(sort_size=sort_size),
streamdata.batched(batch_size),
streamdata.audio_padding(),
- streamdata.audio_cmvn(cmvn_file)
- )
+ streamdata.audio_cmvn(cmvn_file))
if paddle.__version__ >= '2.3.2':
self.loader = streamdata.WebLoader(
- self.dataset,
- num_workers=self.n_iter_processes,
- prefetch_factor = self.prefetch_factor,
- batch_size=None
- )
+ self.dataset,
+ num_workers=self.n_iter_processes,
+ prefetch_factor=self.prefetch_factor,
+ batch_size=None)
else:
self.loader = streamdata.WebLoader(
- self.dataset,
- num_workers=self.n_iter_processes,
- batch_size=None
- )
+ self.dataset,
+ num_workers=self.n_iter_processes,
+ batch_size=None)
def __iter__(self):
return self.loader.__iter__()
@@ -188,7 +199,9 @@ class StreamDataLoader():
return self.__iter__()
def __len__(self):
- logger.info("Stream dataloader does not support calculate the length of the dataset")
+ logger.info(
+ "Stream dataloader does not support calculate the length of the dataset"
+ )
return -1
@@ -347,7 +360,7 @@ class DataLoaderFactory():
config['train_mode'] = True
elif mode == 'valid':
config['manifest'] = config.dev_manifest
- config['train_mode'] = False
+ config['train_mode'] = False
         elif mode == 'test' or mode == 'align':
config['manifest'] = config.test_manifest
config['train_mode'] = False
@@ -358,30 +371,31 @@ class DataLoaderFactory():
config['maxlen_out'] = float('inf')
config['dist_sampler'] = False
else:
- raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
- return StreamDataLoader(
- manifest_file=config.manifest,
- train_mode=config.train_mode,
- unit_type=config.unit_type,
- preprocess_conf=config.preprocess_config,
- batch_size=config.batch_size,
- num_mel_bins=config.feat_dim,
- frame_length=config.window_ms,
- frame_shift=config.stride_ms,
- dither=config.dither,
- minlen_in=config.minlen_in,
- maxlen_in=config.maxlen_in,
- minlen_out=config.minlen_out,
- maxlen_out=config.maxlen_out,
- resample_rate=config.resample_rate,
- shuffle_size=config.shuffle_size,
- sort_size=config.sort_size,
- n_iter_processes=config.num_workers,
- prefetch_factor=config.prefetch_factor,
- dist_sampler=config.dist_sampler,
- cmvn_file=config.cmvn_file,
- vocab_filepath=config.vocab_filepath,
+ raise KeyError(
+ "not valid mode type!!, please input one of 'train, valid, test, align'"
)
+ return StreamDataLoader(
+ manifest_file=config.manifest,
+ train_mode=config.train_mode,
+ unit_type=config.unit_type,
+ preprocess_conf=config.preprocess_config,
+ batch_size=config.batch_size,
+ num_mel_bins=config.feat_dim,
+ frame_length=config.window_ms,
+ frame_shift=config.stride_ms,
+ dither=config.dither,
+ minlen_in=config.minlen_in,
+ maxlen_in=config.maxlen_in,
+ minlen_out=config.minlen_out,
+ maxlen_out=config.maxlen_out,
+ resample_rate=config.resample_rate,
+ shuffle_size=config.shuffle_size,
+ sort_size=config.sort_size,
+ n_iter_processes=config.num_workers,
+ prefetch_factor=config.prefetch_factor,
+ dist_sampler=config.dist_sampler,
+ cmvn_file=config.cmvn_file,
+ vocab_filepath=config.vocab_filepath, )
else:
if mode == 'train':
config['manifest'] = config.train_manifest
@@ -410,7 +424,7 @@ class DataLoaderFactory():
config['train_mode'] = False
config['sortagrad'] = False
config['batch_size'] = config.get('decode', dict()).get(
- 'decode_batch_size', 1)
+ 'decode_batch_size', 1)
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
@@ -426,8 +440,10 @@ class DataLoaderFactory():
config['dist_sampler'] = False
config['shortest_first'] = False
else:
- raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
-
+ raise KeyError(
+ "not valid mode type!!, please input one of 'train, valid, test, align'"
+ )
+
return BatchDataLoader(
json_file=config.manifest,
train_mode=config.train_mode,
@@ -449,4 +465,3 @@ class DataLoaderFactory():
num_encs=config.num_encs,
dist_sampler=config.dist_sampler,
shortest_first=config.shortest_first)
-
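
A note on the shard-list assertion in this file: in Python, assert (condition, message) asserts a two-element tuple, and a non-empty tuple is always truthy, so such a check can never fire. The message belongs after the condition, outside any parentheses:

    assert (1 > 2, "never raised")     # always passes: it asserts a tuple
    assert 1 > 2, "raised when false"  # correct form: condition, then message
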
diff --git a/paddlespeech/s2t/io/sampler.py b/paddlespeech/s2t/io/sampler.py
index ac55af123..89752bb9f 100644
--- a/paddlespeech/s2t/io/sampler.py
+++ b/paddlespeech/s2t/io/sampler.py
@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
- batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
+ batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index e86bbedfa..8a811a52b 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -26,6 +26,8 @@ import paddle
from paddle import jit
from paddle import nn
+from paddlespeech.audio.utils.tensor_utils import add_sos_eos
+from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
@@ -38,8 +40,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
-from paddlespeech.audio.utils.tensor_utils import add_sos_eos
-from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"]
@@ -401,8 +401,8 @@ class U2STBaseModel(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
- att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
- cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
+ att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
@@ -435,8 +435,8 @@ class U2STBaseModel(nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
- return self.encoder.forward_chunk(
- xs, offset, required_cache_size, att_cache, cnn_cache)
+ return self.encoder.forward_chunk(xs, offset, required_cache_size,
+ att_cache, cnn_cache)
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py
index cacda2461..34d796145 100644
--- a/paddlespeech/s2t/modules/align.py
+++ b/paddlespeech/s2t/modules/align.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
+
import paddle
from paddle import nn
-import math
"""
 To align the initializers between paddle and torch,
 the APIs below set a default initializer with higher priority than the global initializer.
@@ -81,10 +82,18 @@ class Linear(nn.Linear):
name=None):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
- weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ weight_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
- bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ bias_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
super(Linear, self).__init__(in_features, out_features, weight_attr,
bias_attr, name)
@@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D):
data_format='NCL'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
- weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ weight_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
- bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ bias_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
super(Conv1D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
@@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D):
data_format='NCHW'):
if weight_attr is None:
if global_init_type == "kaiming_uniform":
- weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ weight_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
if bias_attr is None:
if global_init_type == "kaiming_uniform":
- bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu'))
+ bias_attr = paddle.ParamAttr(
+ initializer=nn.initializer.KaimingUniform(
+ fan_in=None,
+ negative_slope=math.sqrt(5),
+ nonlinearity='leaky_relu'))
super(Conv2D, self).__init__(
in_channels, out_channels, kernel_size, stride, padding, dilation,
groups, padding_mode, weight_attr, bias_attr, data_format)
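
The ParamAttr constructions above repeat the same Kaiming-uniform initializer six times across Linear, Conv1D and Conv2D. A hypothetical helper, not part of this patch, that makes the shared setting explicit:

    import math

    import paddle
    from paddle import nn

    def kaiming_uniform_attr() -> paddle.ParamAttr:
        # Mirrors torch's default Linear/Conv init: Kaiming-uniform with
        # negative_slope=sqrt(5) on leaky_relu, fan_in inferred from shape.
        return paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(
            fan_in=None, negative_slope=math.sqrt(5),
            nonlinearity='leaky_relu'))
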
diff --git a/paddlespeech/s2t/modules/attention.py b/paddlespeech/s2t/modules/attention.py
index b6d615867..ee6dd7fa3 100644
--- a/paddlespeech/s2t/modules/attention.py
+++ b/paddlespeech/s2t/modules/attention.py
@@ -83,11 +83,12 @@ class MultiHeadedAttention(nn.Layer):
return q, k, v
- def forward_attention(self,
- value: paddle.Tensor,
+ def forward_attention(
+ self,
+ value: paddle.Tensor,
scores: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
- ) -> paddle.Tensor:
+ mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+ ) -> paddle.Tensor:
"""Compute attention context vector.
Args:
value (paddle.Tensor): Transformed value, size
@@ -108,7 +109,7 @@ class MultiHeadedAttention(nn.Layer):
# When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
- if paddle.shape(mask)[2] > 0: # time2 > 0
+ if paddle.shape(mask)[2] > 0: # time2 > 0
mask = mask.unsqueeze(1).equal(0) # (batch, 1, *, time2)
# for last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :paddle.shape(scores)[-1]]
@@ -131,9 +132,9 @@ class MultiHeadedAttention(nn.Layer):
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
- pos_emb: paddle.Tensor = paddle.empty([0]),
- cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+ mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+ pos_emb: paddle.Tensor=paddle.empty([0]),
+ cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute scaled dot product attention.
Args:
@@ -247,9 +248,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
query: paddle.Tensor,
key: paddle.Tensor,
value: paddle.Tensor,
- mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
- pos_emb: paddle.Tensor = paddle.empty([0]),
- cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+ mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+ pos_emb: paddle.Tensor=paddle.empty([0]),
+ cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0])
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
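
The default masks above follow a zero-size sentinel convention: a boolean tensor of shape [0, 0, 0] means "no mask", and real masks are detected by a positive time2 dimension, which is what the paddle.shape(mask)[2] > 0 checks test. A sketch assuming only paddle itself:

    import paddle

    no_mask = paddle.ones([0, 0, 0], dtype=paddle.bool)  # sentinel: unmasked
    is_real_mask = paddle.shape(no_mask)[2] > 0          # evaluates to False
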
diff --git a/paddlespeech/s2t/modules/conformer_convolution.py b/paddlespeech/s2t/modules/conformer_convolution.py
index c384b9c78..837ad470d 100644
--- a/paddlespeech/s2t/modules/conformer_convolution.py
+++ b/paddlespeech/s2t/modules/conformer_convolution.py
@@ -106,11 +106,12 @@ class ConvolutionModule(nn.Layer):
)
self.activation = activation
- def forward(self,
- x: paddle.Tensor,
- mask_pad: paddle.Tensor= paddle.ones([0,0,0], dtype=paddle.bool),
- cache: paddle.Tensor= paddle.zeros([0,0,0]),
- ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+ def forward(
+ self,
+ x: paddle.Tensor,
+ mask_pad: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
+ cache: paddle.Tensor=paddle.zeros([0, 0, 0]),
+ ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute convolution module.
Args:
x (paddle.Tensor): Input tensor (#batch, time, channels).
@@ -127,11 +128,11 @@ class ConvolutionModule(nn.Layer):
x = x.transpose([0, 2, 1]) # [B, C, T]
# mask batch padding
- if paddle.shape(mask_pad)[2] > 0: # time > 0
+ if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
if self.lorder > 0:
- if paddle.shape(cache)[2] == 0: # cache_t == 0
+ if paddle.shape(cache)[2] == 0: # cache_t == 0
x = nn.functional.pad(
x, [self.lorder, 0], 'constant', 0.0, data_format='NCL')
else:
@@ -161,7 +162,7 @@ class ConvolutionModule(nn.Layer):
x = self.pointwise_conv2(x)
# mask batch padding
- if paddle.shape(mask_pad)[2] > 0: # time > 0
+ if paddle.shape(mask_pad)[2] > 0: # time > 0
x = x.masked_fill(mask_pad, 0.0)
x = x.transpose([0, 2, 1]) # [B, T, C]
diff --git a/paddlespeech/s2t/modules/encoder.py b/paddlespeech/s2t/modules/encoder.py
index bff2d69bb..dc7b6666a 100644
--- a/paddlespeech/s2t/modules/encoder.py
+++ b/paddlespeech/s2t/modules/encoder.py
@@ -190,9 +190,9 @@ class BaseEncoder(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
- att_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
- cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0]),
- att_mask: paddle.Tensor = paddle.ones([0,0,0], dtype=paddle.bool),
+ att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
+ att_mask: paddle.Tensor=paddle.ones([0, 0, 0], dtype=paddle.bool),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Forward just one chunk
Args:
@@ -227,7 +227,7 @@ class BaseEncoder(nn.Layer):
xs = self.global_cmvn(xs)
# before embed, xs=(B, T, D1), pos_emb=(B=1, T, D)
- xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
+ xs, pos_emb, _ = self.embed(xs, tmp_masks, offset=offset)
# after embed, xs=(B=1, chunk_size, hidden-dim)
elayers = paddle.shape(att_cache)[0]
@@ -252,14 +252,16 @@ class BaseEncoder(nn.Layer):
# att_cache[i:i+1] = (1, head, cache_t1, d_k*2)
# cnn_cache[i:i+1] = (1, B=1, hidden-dim, cache_t2)
xs, _, new_att_cache, new_cnn_cache = layer(
- xs, att_mask, pos_emb,
- att_cache=att_cache[i:i+1] if elayers > 0 else att_cache,
- cnn_cache=cnn_cache[i:i+1] if paddle.shape(cnn_cache)[0] > 0 else cnn_cache,
- )
+ xs,
+ att_mask,
+ pos_emb,
+ att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
+ cnn_cache=cnn_cache[i:i + 1]
+ if paddle.shape(cnn_cache)[0] > 0 else cnn_cache, )
# new_att_cache = (1, head, attention_key_size, d_k*2)
# new_cnn_cache = (B=1, hidden-dim, cache_t2)
- r_att_cache.append(new_att_cache[:,:, next_cache_start:, :])
- r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
+ r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
+ r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) # add elayer dim
if self.normalize_before:
xs = self.after_norm(xs)
@@ -270,7 +272,6 @@ class BaseEncoder(nn.Layer):
r_cnn_cache = paddle.concat(r_cnn_cache, axis=0)
return xs, r_att_cache, r_cnn_cache
-
def forward_chunk_by_chunk(
self,
xs: paddle.Tensor,
@@ -315,8 +316,8 @@ class BaseEncoder(nn.Layer):
num_frames = xs.shape[1]
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
- att_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
- cnn_cache: paddle.Tensor = paddle.zeros([0,0,0,0])
+ att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
+ cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0])
outputs = []
offset = 0
@@ -326,7 +327,7 @@ class BaseEncoder(nn.Layer):
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = self.forward_chunk(
- chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+ chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
outputs.append(y)
offset += y.shape[1]
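
forward_chunk_by_chunk() above threads the attention and CNN caches through successive forward_chunk() calls, seeded with empty zero-shaped tensors. A condensed sketch of that loop, assuming an encoder with the forward_chunk signature from this patch and ignoring the subsampling context the real loop accounts for:

    import paddle

    def stream_encode(encoder, xs, chunk_size: int, required_cache_size: int):
        att_cache = paddle.zeros([0, 0, 0, 0])  # empty: no history yet
        cnn_cache = paddle.zeros([0, 0, 0, 0])
        outputs, offset = [], 0
        for cur in range(0, xs.shape[1], chunk_size):
            y, att_cache, cnn_cache = encoder.forward_chunk(
                xs[:, cur:cur + chunk_size, :], offset, required_cache_size,
                att_cache, cnn_cache)
            outputs.append(y)
            offset += y.shape[1]  # advance in encoder output frames
        return paddle.concat(outputs, axis=1)
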
diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py
index cdcf2e052..e37837d2f 100644
--- a/paddlespeech/s2t/modules/initializer.py
+++ b/paddlespeech/s2t/modules/initializer.py
@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
+
class DefaultInitializerContext(object):
"""
egs:
diff --git a/paddlespeech/server/bin/paddlespeech_server.py b/paddlespeech/server/bin/paddlespeech_server.py
index 175e8ffb6..11f50655f 100644
--- a/paddlespeech/server/bin/paddlespeech_server.py
+++ b/paddlespeech/server/bin/paddlespeech_server.py
@@ -18,7 +18,6 @@ from typing import List
import uvicorn
from fastapi import FastAPI
-from starlette.middleware.cors import CORSMiddleware
from prettytable import PrettyTable
from starlette.middleware.cors import CORSMiddleware
@@ -46,6 +45,7 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"])
+
@cli_server_register(
name='paddlespeech_server.start', description='Start the service')
class ServerExecutor(BaseExecutor):
diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py
index b87dbe805..1b8ad1cb7 100644
--- a/paddlespeech/server/engine/asr/online/ctc_endpoint.py
+++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py
@@ -102,8 +102,10 @@ class OnlineCTCEndpoint:
assert self.num_frames_decoded >= self.trailing_silence_frames
assert self.frame_shift_in_ms > 0
-
- decoding_something = (self.num_frames_decoded > self.trailing_silence_frames) and decoding_something
+
+ decoding_something = (
+ self.num_frames_decoded > self.trailing_silence_frames
+ ) and decoding_something
utterance_length = self.num_frames_decoded * self.frame_shift_in_ms
trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms
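
The guard above only counts the utterance as having decoded something once more frames have been decoded than the trailing-silence window, so an endpoint cannot fire on silence alone. As a standalone predicate:

    def decoded_something(num_frames_decoded: int,
                          trailing_silence_frames: int,
                          had_output: bool) -> bool:
        # True only when decoding progressed beyond the trailing silence.
        return had_output and num_frames_decoded > trailing_silence_frames
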
diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
index ab4f11305..6daae5be3 100644
--- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py
@@ -21,12 +21,12 @@ import paddle
from numpy import float32
from yacs.config import CfgNode
+from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
-from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.server.utils import onnx_infer
diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
index 182e64180..0fd5d1bc6 100644
--- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py
@@ -21,10 +21,10 @@ import paddle
from numpy import float32
from yacs.config import CfgNode
+from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
-from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.utility import UpdateConfig
diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py
index 4df38f09d..269a33ba7 100644
--- a/paddlespeech/server/engine/asr/online/python/asr_engine.py
+++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py
@@ -21,10 +21,10 @@ import paddle
from numpy import float32
from yacs.config import CfgNode
+from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource
-from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
@@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler:
## conformer
# cache for conformer online
- self.att_cache = paddle.zeros([0,0,0,0])
- self.cnn_cache = paddle.zeros([0,0,0,0])
+ self.att_cache = paddle.zeros([0, 0, 0, 0])
+ self.cnn_cache = paddle.zeros([0, 0, 0, 0])
self.encoder_out = None
# conformer decoding state
@@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler:
# cur chunk
chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk
- (y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk(
- chunk_xs, self.offset, required_cache_size,
- self.att_cache, self.cnn_cache)
+ (y, self.att_cache,
+ self.cnn_cache) = self.model.encoder.forward_chunk(
+ chunk_xs, self.offset, required_cache_size, self.att_cache,
+ self.cnn_cache)
outputs.append(y)
# update the global offset, in decoding frame unit
diff --git a/paddlespeech/t2s/exps/ernie_sat/align.py b/paddlespeech/t2s/exps/ernie_sat/align.py
index 529a8221c..513c57e70 100755
--- a/paddlespeech/t2s/exps/ernie_sat/align.py
+++ b/paddlespeech/t2s/exps/ernie_sat/align.py
@@ -19,9 +19,9 @@ import librosa
import numpy as np
import pypinyin
from praatio import textgrid
-from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
-from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
+from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
DICT_EN = 'tools/aligner/cmudict-0.7b'
DICT_ZH = 'tools/aligner/simple.lexicon'
@@ -30,6 +30,7 @@ MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
MFA_PATH = 'tools/montreal-forced-aligner/bin'
os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
+
def _get_max_idx(dic):
return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
@@ -106,11 +107,11 @@ def alignment(wav_path: str,
wav_name = os.path.basename(wav_path)
utt = wav_name.split('.')[0]
# prepare data for MFA
- tmp_name = get_tmp_name(text=text)
+ tmp_name = get_tmp_name(text=text)
tmpbase = './tmp_dir/' + tmp_name
tmpbase = Path(tmpbase)
tmpbase.mkdir(parents=True, exist_ok=True)
- print("tmp_name in alignment:",tmp_name)
+ print("tmp_name in alignment:", tmp_name)
shutil.copyfile(wav_path, tmpbase / wav_name)
txt_name = utt + '.txt'
diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
index 95b07367c..fac4cef87 100644
--- a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
@@ -15,31 +15,24 @@ import librosa
import numpy as np
import soundfile as sf
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
-from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
from paddlespeech.t2s.exps.syn_utils import get_frontend
-from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.exps.syn_utils import norm
-from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
-
-
-
-
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp
- phonemes = [
- phn if phn in vocab_phones else "sp" for phn in phonemes
- ]
+ phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
phone_ids = [vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
-
def prep_feats_with_dur(wav_path: str,
old_str: str='',
new_str: str='',
@@ -152,8 +145,6 @@ def prep_feats_with_dur(wav_path: str,
return outs
-
-
def prep_feats(wav_path: str,
old_str: str='',
new_str: str='',
@@ -188,32 +179,32 @@ def prep_feats(wav_path: str,
mel = mel_extractor.get_log_mel_fbank(wav)
erniesat_mean, erniesat_std = np.load(erniesat_stat)
normed_mel = norm(mel, erniesat_mean, erniesat_std)
- tmp_name = get_tmp_name(text=old_str)
+ tmp_name = get_tmp_name(text=old_str)
tmpbase = './tmp_dir/' + tmp_name
tmpbase = Path(tmpbase)
tmpbase.mkdir(parents=True, exist_ok=True)
- print("tmp_name in synthesize_e2e:",tmp_name)
+ print("tmp_name in synthesize_e2e:", tmp_name)
mel_path = tmpbase / 'mel.npy'
- print("mel_path:",mel_path)
+ print("mel_path:", mel_path)
np.save(mel_path, logmel)
durations = [e - s for e, s in zip(mfa_end, mfa_start)]
- datum={
- "utt_id": utt_id,
- "spk_id": 0,
- "text": text,
- "text_lengths": len(text),
- "speech_lengths": 115,
- "durations": durations,
- "speech": mel_path,
- "align_start": mfa_start,
+ datum = {
+ "utt_id": utt_id,
+ "spk_id": 0,
+ "text": text,
+ "text_lengths": len(text),
+ "speech_lengths": 115,
+ "durations": durations,
+ "speech": mel_path,
+ "align_start": mfa_start,
"align_end": mfa_end,
"span_bdy": span_bdy
}
batch = collate_fn([datum])
- print("batch:",batch)
+ print("batch:", batch)
return batch, old_span_bdy, new_span_bdy
@@ -240,8 +231,6 @@ def decode_with_model(mlm_model: nn.Layer,
fs=fs,
n_shift=n_shift,
token_list=token_list)
-
-
feats = collate_fn(batch)[1]
@@ -275,7 +264,6 @@ if __name__ == '__main__':
# for synthesize
append_str = "do you love me i love you so much"
new_str = old_str + append_str
-
'''
outs = prep_feats_with_dur(
wav_path=wav_path,
@@ -306,7 +294,7 @@ if __name__ == '__main__':
with open(erniesat_config) as f:
erniesat_config = CfgNode(yaml.safe_load(f))
-
+
erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"
# Extractor
@@ -319,16 +307,14 @@ if __name__ == '__main__':
n_mels=erniesat_config.n_mels,
fmin=erniesat_config.fmin,
fmax=erniesat_config.fmax)
-
-
collate_fn = build_erniesat_collate_fn(
mlm_prob=erniesat_config.mlm_prob,
mean_phn_span=erniesat_config.mean_phn_span,
seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
text_masking=False)
-
- phones_dict='/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
+
+ phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
vocab_phones = {}
with open(phones_dict, 'rt') as f:
@@ -336,11 +322,9 @@ if __name__ == '__main__':
for phn, id in phn_id:
vocab_phones[phn] = int(id)
- prep_feats(wav_path=wav_path,
- old_str=old_str,
- new_str=new_str,
- fs=fs,
- n_shift=n_shift)
-
-
-
+ prep_feats(
+ wav_path=wav_path,
+ old_str=old_str,
+ new_str=new_str,
+ fs=fs,
+ n_shift=n_shift)
diff --git a/paddlespeech/t2s/exps/ernie_sat/utils.py b/paddlespeech/t2s/exps/ernie_sat/utils.py
index 9169efa36..6805e513c 100644
--- a/paddlespeech/t2s/exps/ernie_sat/utils.py
+++ b/paddlespeech/t2s/exps/ernie_sat/utils.py
@@ -11,32 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import os
from pathlib import Path
from typing import Dict
from typing import List
from typing import Union
-import os
import numpy as np
import paddle
import yaml
from yacs.config import CfgNode
-import hashlib
-
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
def _get_user():
return os.path.expanduser('~').split('/')[-1]
+
def str2md5(string):
md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
return md5_val
-def get_tmp_name(text:str):
+
+def get_tmp_name(text: str):
return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)
+
def get_dict(dictfile: str):
word2phns_dict = {}
with open(dictfile, 'r') as fid:
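
get_tmp_name() above composes <user>_<pid>_<md5(text)>, so identical text from the same process maps to the same temp directory while different users and runs stay separated. A self-contained equivalent sketch:

    import hashlib
    import os

    def tmp_name(text: str) -> str:
        user = os.path.expanduser('~').split('/')[-1]
        digest = hashlib.md5(text.encode('utf8')).hexdigest()
        return f"{user}_{os.getpid()}_{digest}"
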
diff --git a/paddlespeech/t2s/exps/tacotron2/normalize.py b/paddlespeech/t2s/exps/tacotron2/normalize.py
index 64848f899..6c250a3b9 120000
--- a/paddlespeech/t2s/exps/tacotron2/normalize.py
+++ b/paddlespeech/t2s/exps/tacotron2/normalize.py
@@ -1 +1 @@
-../transformer_tts/normalize.py
\ No newline at end of file
+#../transformer_tts/normalize.py
diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
index 87f6c7cff..dfd9e6352 100644
--- a/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
+++ b/paddlespeech/t2s/exps/transformer_tts/preprocess_new.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
@@ -24,165 +25,15 @@ import librosa
import numpy as np
import tqdm
import yaml
-from yacs.config import CfgNode as Configuration
-import re
-from paddlespeech.t2s.datasets.get_feats import LogMelFBank
-from paddlespeech.t2s.frontend import English,Chinese
-
-
-def get_lj_sentences(file_name, frontend):
- '''read MFA duration.txt
-
- Args:
- file_name (str or Path)
- Returns:
- Dict: sentence: {'utt': ([char], [int])}
- '''
- f = open(file_name, 'r')
- sentence = {}
- speaker_set = set()
- for line in f:
- line_list = line.strip().split('|')
- utt = line_list[0]
- speaker = utt.split("-")[0][:2]
- speaker_set.add(speaker)
- raw_text = line_list[-1]
- phonemes = frontend.phoneticize(raw_text)
- phonemes = phonemes[1:-1]
- phonemes = [phn for phn in phonemes if not phn.isspace()]
- sentence[utt] = (phonemes, speaker)
- f.close()
- return sentence, speaker_set
-
-def get_csmsc_sentences(file_name,fronten):
- '''read MFA duration.txt
-
- Args:
- file_name (str or Path)
- Returns:
- Dict: sentence: {'utt': ([char], [int])}
- '''
- sentence = {}
- speaker_set = set()
- utt = 'girl'
- with open(file_name, mode='r', encoding='utf-8') as f:
- lines = f.readlines()
- for i in range(len(lines)):
- ann = lines[i]
- if i % 2 == 0:
- head = ann.strip('\n|\t').split('\t')
- body = re.sub(r'[0-9]|#', '', head[-1])
- phonemes = fronten.phoneticize(body)
- phonemes = phonemes[1:-1]
- phonemes = [phn for phn in phonemes if not phn.isspace()]
- sentence[head[0]] = (phonemes, utt)
- speaker_set.add(utt)
- f.close()
- return sentence, speaker_set
-
-def get_input_token(sentence, output_path):
- '''get phone set from training data and save it
-
- Args:
- sentence (Dict): sentence: {'utt': ([char], str)}
- output_path (str or path): path to save phone_id_map
- '''
- phn_token = set()
- for utt in sentence:
- for phn in sentence[utt][0]:
- if phn != "":
- phn_token.add(phn)
- phn_token = list(phn_token)
- phn_token.sort()
- phn_token = ["", ""] + phn_token
- phn_token += [""]
-
- with open(output_path, 'w') as f:
- for i, phn in enumerate(phn_token):
- f.write(phn + ' ' + str(i) + '\n')
-
-
-def get_spk_id_map(speaker_set, output_path):
- speakers = sorted(list(speaker_set))
- with open(output_path, 'w') as f:
- for i, spk in enumerate(speakers):
- f.write(spk + ' ' + str(i) + '\n')
-
-
-def process_sentence(config: Dict[str, Any],
- fp: Path,
- sentences: Dict,
- output_dir: Path,
- mel_extractor=None):
- utt_id = fp.stem
- record = None
- if utt_id in sentences:
- # reading, resampling may occur
- wav, _ = librosa.load(str(fp), sr=config.fs)
- if len(wav.shape) != 1 or np.abs(wav).max() > 1.0:
- return record
- assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
- assert np.abs(wav).max(
- ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
- phones = sentences[utt_id][0]
- speaker = sentences[utt_id][1]
- logmel = mel_extractor.get_log_mel_fbank(wav, base='e')
- # change duration according to mel_length
- num_frames = logmel.shape[0]
- mel_dir = output_dir / "data_speech"
- mel_dir.mkdir(parents=True, exist_ok=True)
- mel_path = mel_dir / (utt_id + "_speech.npy")
- np.save(mel_path, logmel)
- record = {
- "utt_id": utt_id,
- "phones": phones,
- "text_lengths": len(phones),
- "speech_lengths": num_frames,
- "speech": str(mel_path),
- "speaker": speaker
- }
- return record
-
-
-def process_sentences(config,
- fps: List[Path],
- sentences: Dict,
- output_dir: Path,
- mel_extractor=None,
- nprocs: int=1):
-
- if nprocs == 1:
- results = []
- for fp in tqdm.tqdm(fps, total=len(fps)):
- record = process_sentence(
- config=config,
- fp=fp,
- sentences=sentences,
- output_dir=output_dir,
- mel_extractor=mel_extractor)
- if record:
- results.append(record)
- else:
- with ThreadPoolExecutor(nprocs) as pool:
- futures = []
- with tqdm.tqdm(total=len(fps)) as progress:
- for fp in fps:
- future = pool.submit(process_sentence, config, fp,
- sentences, output_dir, mel_extractor)
- future.add_done_callback(lambda p: progress.update())
- futures.append(future)
-
- results = []
- for ft in futures:
- record = ft.result()
- if record:
- results.append(record)
+from yacs.config import CfgNode
- results.sort(key=itemgetter("utt_id"))
- with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
- for item in results:
- writer.write(item)
- print("Done")
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
+from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.utils import str2bool
def main():
@@ -192,71 +43,113 @@ def main():
parser.add_argument(
"--dataset",
- default="csmsc",
+ default="baker",
type=str,
- help="name of dataset, should in {ljspeech,csmsc} now")
+ help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now")
parser.add_argument(
- "--rootdir", default='./BZNSYP/', type=str, help="directory to dataset.")
+ "--rootdir", default=None, type=str, help="directory to dataset.")
parser.add_argument(
"--dumpdir",
type=str,
- default='./dump/',
- #required=True,
+ required=True,
help="directory to dump feature files.")
-
parser.add_argument(
- "--config-path",
- default="./default.yaml",
- type=str,
- help="yaml format configuration file.")
+ "--dur-file", default=None, type=str, help="path to durations.txt.")
+
+ parser.add_argument("--config", type=str, help="transformer config file.")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
+ parser.add_argument(
+ "--cut-sil",
+ type=str2bool,
+ default=True,
+ help="whether cut sil in the edge of audio")
+
+ parser.add_argument(
+ "--spk_emb_dir",
+ default=None,
+ type=str,
+ help="directory to speaker embedding files.")
args = parser.parse_args()
- config_path = Path(args.config_path).resolve()
- root_dir = Path(args.rootdir).expanduser()
+
+ rootdir = Path(args.rootdir).expanduser()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
+ dur_file = Path(args.dur_file).expanduser()
- assert root_dir.is_dir()
+ if args.spk_emb_dir:
+ spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
+ else:
+ spk_emb_dir = None
+
+ assert rootdir.is_dir()
+ assert dur_file.is_file()
- with open(config_path, 'rt') as f:
- _C = yaml.safe_load(f)
- _C = Configuration(_C)
- config = _C.clone()
+ with open(args.config, 'rt') as f:
+ config = CfgNode(yaml.safe_load(f))
+ sentences, speaker_set = get_phn_dur(dur_file)
+
+ merge_silence(sentences)
phone_id_map_path = dumpdir / "phone_id_map.txt"
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
- if args.dataset == "csmsc":
- wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
- frontend = Chinese()
- sentences, speaker_set = get_csmsc_sentences(root_dir / "000001-010000.txt", frontend)
- print(speaker_set)
- get_input_token(sentences, phone_id_map_path)
- get_spk_id_map(speaker_set, speaker_id_map_path)
- num_train = 9000
+ get_input_token(sentences, phone_id_map_path, args.dataset)
+ get_spk_id_map(speaker_set, speaker_id_map_path)
+
+ if args.dataset == "baker":
+ wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
+ # split data into 3 sections
+ num_train = 9800
num_dev = 100
train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]
-
- if args.dataset == "ljspeech":
- wav_files = sorted(list((root_dir / "wavs").rglob("*.wav")))
- frontend = English()
- sentences, speaker_set = get_lj_sentences(root_dir / "metadata.csv",frontend)
- get_input_token(sentences, phone_id_map_path)
- get_spk_id_map(speaker_set, speaker_id_map_path)
+ elif args.dataset == "aishell3":
+ sub_num_dev = 5
+ wav_dir = rootdir / "train" / "wav"
+ train_wav_files = []
+ dev_wav_files = []
+ test_wav_files = []
+ for speaker in os.listdir(wav_dir):
+ wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
+ if len(wav_files) > 100:
+ train_wav_files += wav_files[:-sub_num_dev * 2]
+ dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+ test_wav_files += wav_files[-sub_num_dev:]
+ else:
+ train_wav_files += wav_files
+
+ elif args.dataset == "ljspeech":
+ wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
# split data into 3 sections
num_train = 12900
num_dev = 100
train_wav_files = wav_files[:num_train]
dev_wav_files = wav_files[num_train:num_train + num_dev]
test_wav_files = wav_files[num_train + num_dev:]
+ elif args.dataset == "vctk":
+ sub_num_dev = 5
+ wav_dir = rootdir / "wav48_silence_trimmed"
+ train_wav_files = []
+ dev_wav_files = []
+ test_wav_files = []
+ for speaker in os.listdir(wav_dir):
+ wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
+ if len(wav_files) > 100:
+ train_wav_files += wav_files[:-sub_num_dev * 2]
+ dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+ test_wav_files += wav_files[-sub_num_dev:]
+ else:
+ train_wav_files += wav_files
+
+ else:
+        print("dataset should be one of {baker, aishell3, ljspeech, vctk}!")
train_dump_dir = dumpdir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True)
@@ -284,7 +177,9 @@ def main():
sentences=sentences,
output_dir=train_dump_dir,
mel_extractor=mel_extractor,
- nprocs=args.num_cpu)
+ nprocs=args.num_cpu,
+ cut_sil=args.cut_sil,
+ spk_emb_dir=spk_emb_dir)
if dev_wav_files:
process_sentences(
config=config,
@@ -292,7 +187,8 @@ def main():
sentences=sentences,
output_dir=dev_dump_dir,
mel_extractor=mel_extractor,
- nprocs=args.num_cpu)
+ cut_sil=args.cut_sil,
+ spk_emb_dir=spk_emb_dir)
if test_wav_files:
process_sentences(
config=config,
@@ -300,7 +196,9 @@ def main():
sentences=sentences,
output_dir=test_dump_dir,
mel_extractor=mel_extractor,
- nprocs=args.num_cpu)
+ nprocs=args.num_cpu,
+ cut_sil=args.cut_sil,
+ spk_emb_dir=spk_emb_dir)
if __name__ == "__main__":
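A note on the dataset branching added above: baker and ljspeech are split by fixed utterance counts, while the multi-speaker sets (aishell3, vctk) hold out a fixed tail per speaker. A minimal standalone sketch of that per-speaker rule, with hypothetical file names and counts (not taken from the patch):

    from pathlib import Path

    def split_speaker(wav_files, sub_num_dev=5):
        # speakers with more than 100 clips donate their last
        # 2 * sub_num_dev clips to dev and test; smaller speakers
        # go entirely to train, exactly as in the hunk above
        if len(wav_files) > 100:
            return (wav_files[:-sub_num_dev * 2],
                    wav_files[-sub_num_dev * 2:-sub_num_dev],
                    wav_files[-sub_num_dev:])
        return wav_files, [], []

    # e.g. a speaker with 120 clips contributes 110/5/5 to train/dev/test
    train, dev, test = split_speaker([Path("%03d.wav" % i) for i in range(120)])
    assert (len(train), len(dev), len(test)) == (110, 5, 5)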
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize.py b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
index 7b6b1873f..b66bbdc62 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize.py
@@ -50,11 +50,14 @@ def evaluate(args, acoustic_model_config, vocoder_config):
model.set_state_dict(
paddle.load(args.transformer_tts_checkpoint)["main_params"])
model.eval()
+
- # remove ".pdparams" in waveflow_checkpoint
+ # build the vocoder inference wrapper selected by --voc
- vocoder_checkpoint_path = args.waveflow_checkpoint[:-9] if args.waveflow_checkpoint.endswith(
- ".pdparams") else args.waveflow_checkpoint
- vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config,
- vocoder_checkpoint_path)
+ vocoder = get_voc_inference(
+ voc=args.voc,
+ voc_config=vocoder_config,
+ voc_ckpt=args.voc_ckpt,
+ voc_stat=args.voc_stat)
+
layer_tools.recursively_remove_weight_norm(vocoder)
vocoder.eval()
print("model done!")
@@ -78,9 +81,8 @@ def evaluate(args, acoustic_model_config, vocoder_config):
with paddle.no_grad():
mel = transformer_tts_inference(text)
-            # mel shape is (T, feats) and waveflow's input shape is (batch, feats, T)
-            mel = mel.unsqueeze(0).transpose([0, 2, 1])
-            # wavflow's output shape is (B, T)
-            wav = vocoder.infer(mel)[0]
+            # the vocoder inference wrapper normalizes and reshapes internally:
+            # it takes mel of shape (T, feats) and returns the waveform
+            wav = vocoder(mel)
sf.write(
str(output_dir / (utt_id + ".wav")),
@@ -106,18 +108,34 @@ def main():
type=str,
help="mean and standard deviation used to normalize spectrogram when training transformer tts."
)
+
+ # vocoder
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='pwgan_csmsc',
+ choices=[
+ 'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
+ 'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
+ 'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
+ 'style_melgan_csmsc'
+ ],
+        help='Choose the vocoder type for the TTS task.')
parser.add_argument(
- "--waveflow-config", type=str, help="waveflow config file.")
- # not normalize when training waveflow
+ '--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
- "--waveflow-checkpoint", type=str, help="waveflow checkpoint to load.")
+ '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
- "--phones-dict", type=str, default=None, help="phone vocabulary file.")
-
- parser.add_argument("--test-metadata", type=str, help="test metadata.")
- parser.add_argument("--output-dir", type=str, help="output dir.")
+ "--voc_stat",
+ type=str,
+ default=None,
+ help="mean and standard deviation used to normalize spectrogram when training voc."
+ )
+ # other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+ parser.add_argument("--test_metadata", type=str, help="test metadata.")
+ parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
@@ -130,16 +148,16 @@ def main():
with open(args.transformer_tts_config) as f:
transformer_tts_config = CfgNode(yaml.safe_load(f))
- with open(args.waveflow_config) as f:
- waveflow_config = CfgNode(yaml.safe_load(f))
+ with open(args.voc_config) as f:
+ voc_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(transformer_tts_config)
- print(waveflow_config)
+ print(voc_config)
- evaluate(args, transformer_tts_config, waveflow_config)
+ evaluate(args, transformer_tts_config, voc_config)
if __name__ == "__main__":
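For context on the hunks above: get_voc_inference replaces the hard-wired WaveFlow path with whichever vocoder --voc names. Below is a minimal sketch of the call pattern, assuming only the keyword arguments visible in the diff; the import path, config file, checkpoint and stats names are assumptions, not part of this patch:

    import yaml
    from yacs.config import CfgNode
    # assumed helper location; verify against your checkout
    from paddlespeech.t2s.exps.syn_utils import get_voc_inference

    with open("pwg_default.yaml") as f:  # hypothetical vocoder config
        voc_config = CfgNode(yaml.safe_load(f))

    vocoder = get_voc_inference(
        voc="pwgan_csmsc",
        voc_config=voc_config,
        voc_ckpt="pwg_snapshot_iter_400000.pdz",  # hypothetical checkpoint
        voc_stat="pwg_stats.npy")  # hypothetical normalization stats

    # the wrapper is a plain callable mapping mel (T, n_mels) -> wav,
    # which is why the manual unsqueeze/transpose for WaveFlow is gone:
    # wav = vocoder(mel)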
diff --git a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
index 0cd7d224e..716deb7bc 100644
--- a/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
+++ b/paddlespeech/t2s/exps/transformer_tts/synthesize_e2e.py
@@ -21,6 +21,7 @@ import soundfile as sf
import yaml
from yacs.config import CfgNode
+from paddlespeech.t2s.frontend import Chinese
from paddlespeech.t2s.frontend import English
from paddlespeech.t2s.models.transformer_tts import TransformerTTS
from paddlespeech.t2s.models.transformer_tts import TransformerTTSInference
@@ -59,15 +60,15 @@ def evaluate(args, acoustic_model_config, vocoder_config):
model.eval()
- # remove ".pdparams" in waveflow_checkpoint
+ # build the vocoder inference wrapper selected by --voc
- vocoder_checkpoint_path = args.waveflow_checkpoint[:-9] if args.waveflow_checkpoint.endswith(
- ".pdparams") else args.waveflow_checkpoint
- vocoder = ConditionalWaveFlow.from_pretrained(vocoder_config,
- vocoder_checkpoint_path)
- layer_tools.recursively_remove_weight_norm(vocoder)
+ vocoder = get_voc_inference(
+ voc=args.voc,
+ voc_config=vocoder_config,
+ voc_ckpt=args.voc_ckpt,
+ voc_stat=args.voc_stat)
vocoder.eval()
print("model done!")
- frontend = English()
+ frontend = Chinese()
print("frontend done!")
stat = np.load(args.transformer_tts_stat)
@@ -90,11 +91,10 @@ def evaluate(args, acoustic_model_config, vocoder_config):
phones = [phn if phn in phone_id_map else "," for phn in phones]
phone_ids = [phone_id_map[phn] for phn in phones]
with paddle.no_grad():
- mel = transformer_tts_inference(paddle.to_tensor(phone_ids))
- # mel shape is (T, feats) and waveflow's input shape is (batch, feats, T)
- mel = mel.unsqueeze(0).transpose([0, 2, 1])
- # wavflow's output shape is (B, T)
- wav = vocoder.infer(mel)[0]
+ tensor_phone_ids = paddle.to_tensor(phone_ids)
+ mel = transformer_tts_inference(tensor_phone_ids)
+
+ wav = vocoder(mel)
sf.write(
str(output_dir / (utt_id + ".wav")),
@@ -120,23 +120,51 @@ def main():
type=str,
help="mean and standard deviation used to normalize spectrogram when training transformer tts."
)
+
+ # vocoder
+ parser.add_argument(
+ '--voc',
+ type=str,
+ default='pwgan_csmsc',
+ choices=[
+ 'pwgan_csmsc',
+ 'pwgan_ljspeech',
+ 'pwgan_aishell3',
+ 'pwgan_vctk',
+ 'mb_melgan_csmsc',
+ 'style_melgan_csmsc',
+ 'hifigan_csmsc',
+ 'hifigan_ljspeech',
+ 'hifigan_aishell3',
+ 'hifigan_vctk',
+ 'wavernn_csmsc',
+ ],
+        help='Choose the vocoder type for the TTS task.')
parser.add_argument(
- "--waveflow-config", type=str, help="waveflow config file.")
- # not normalize when training waveflow
+ '--voc_config', type=str, default=None, help='Config of voc.')
parser.add_argument(
- "--waveflow-checkpoint", type=str, help="waveflow checkpoint to load.")
+ '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
parser.add_argument(
- "--phones-dict",
+ "--voc_stat",
type=str,
- default="phone_id_map.txt",
- help="phone vocabulary file.")
+ default=None,
+ help="mean and standard deviation used to normalize spectrogram when training voc."
+ )
+ # other
parser.add_argument(
- "--text",
+ "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+ parser.add_argument(
+ '--lang',
type=str,
- help="text to synthesize, a 'utt_id sentence' pair per line.")
- parser.add_argument("--output-dir", type=str, help="output dir.")
+        default='zh',
+        help='Choose model language: zh, en or mix.')
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+ parser.add_argument(
+ "--text",
+ type=str,
+ help="text to synthesize, a 'utt_id sentence' pair per line.")
+ parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
@@ -149,16 +177,16 @@ def main():
with open(args.transformer_tts_config) as f:
transformer_tts_config = CfgNode(yaml.safe_load(f))
- with open(args.waveflow_config) as f:
- waveflow_config = CfgNode(yaml.safe_load(f))
+ with open(args.voc_config) as f:
+ voc_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(transformer_tts_config)
- print(waveflow_config)
+ print(voc_config)
- evaluate(args, transformer_tts_config, waveflow_config)
+ evaluate(args, transformer_tts_config, voc_config)
if __name__ == "__main__":
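One detail in this script that the frontend swap makes easy to miss: phones absent from phone_id_map fall back to "," before lookup, so an out-of-vocabulary symbol degrades to a pause instead of raising a KeyError. Read in isolation (the toy table below is hypothetical; real ids come from dump/phone_id_map.txt):

    phone_id_map = {",": 0, "n": 1, "i3": 2, "h": 3, "ao3": 4}  # toy table

    def to_ids(phones):
        # unknown phones are mapped to "," (treated as a pause)
        phones = [p if p in phone_id_map else "," for p in phones]
        return [phone_id_map[p] for p in phones]

    assert to_ids(["n", "i3", "h", "ao3", "oov"]) == [1, 2, 3, 4, 0]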
diff --git a/paddlespeech/t2s/modules/transformer/repeat.py b/paddlespeech/t2s/modules/transformer/repeat.py
index 43d11e9f9..d7f45e676 100644
--- a/paddlespeech/t2s/modules/transformer/repeat.py
+++ b/paddlespeech/t2s/modules/transformer/repeat.py
@@ -38,4 +38,4 @@ def repeat(N, fn):
Returns:
MultiSequential: Repeated model instance.
"""
- return MultiSequential(* [fn(n) for n in range(N)])
+ return MultiSequential(*[fn(n) for n in range(N)])
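The repeat.py change is purely cosmetic (yapf spacing). For readers new to the helper, this is roughly what it builds; MultiSequential forwards multiple arguments through the stack, so the single-argument nn.Sequential below is only a stand-in:

    import paddle.nn as nn

    def repeat(N, fn):
        # build N independent layers fn(0), ..., fn(N-1) and stack them
        return nn.Sequential(*[fn(n) for n in range(N)])

    blocks = repeat(3, lambda n: nn.Linear(8, 8))  # three separate Linear layers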
diff --git a/setup.py b/setup.py
index 1cc82fa76..b9e1f7c73 100644
--- a/setup.py
+++ b/setup.py
@@ -76,12 +76,7 @@ base = [
"pybind11",
]
-server = [
- "fastapi",
- "uvicorn",
- "pattern_singleton",
- "websockets"
-]
+server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]
requirements = {
"install":
diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
index 4426d1be8..386b4d9b8 100755
--- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
+++ b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py
@@ -514,8 +518,8 @@ class SymbolicShapeInference:
if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']:
initializers = [
self.initializers_[name] for name in node.input
- if (name in self.initializers_ and
- name not in self.graph_inputs_)
+ if (name in self.initializers_ and name not in
+ self.graph_inputs_)
]
# run single node inference with self.known_vi_ shapes
@@ -601,8 +605,8 @@ class SymbolicShapeInference:
for o in symbolic_shape_inference.out_mp_.graph.output
]
subgraph_new_symbolic_dims = set([
- d for s in subgraph_shapes if s for d in s
- if type(d) == str and not d in self.symbolic_dims_
+ d for s in subgraph_shapes
+ if s for d in s if type(d) == str and not d in self.symbolic_dims_
])
new_dims = {}
for d in subgraph_new_symbolic_dims:
@@ -729,8 +733,9 @@ class SymbolicShapeInference:
for d, s in zip(sympy_shape[-rank:], strides)
]
total_pads = [
- max(0, (k - s) if r == 0 else (k - r)) for k, s, r in
- zip(effective_kernel_shape, strides, residual)
+ max(0, (k - s) if r == 0 else (k - r))
+ for k, s, r in zip(effective_kernel_shape, strides,
+ residual)
]
except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational
total_pads = [
@@ -1276,8 +1281,9 @@ class SymbolicShapeInference:
if pads is not None:
assert len(pads) == 2 * rank
new_sympy_shape = [
- d + pad_up + pad_down for d, pad_up, pad_down in
- zip(sympy_shape, pads[:rank], pads[rank:])
+ d + pad_up + pad_down
+ for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[
+ rank:])
]
self._update_computed_dims(new_sympy_shape)
else:
@@ -1590,8 +1596,8 @@ class SymbolicShapeInference:
scales = list(scales)
new_sympy_shape = [
sympy.simplify(sympy.floor(d * (end - start) * scale))
- for d, start, end, scale in
- zip(input_sympy_shape, roi_start, roi_end, scales)
+ for d, start, end, scale in zip(input_sympy_shape,
+ roi_start, roi_end, scales)
]
self._update_computed_dims(new_sympy_shape)
else:
@@ -2204,8 +2210,9 @@ class SymbolicShapeInference:
# topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate
sorted_nodes = []
sorted_known_vi = set([
- i.name for i in list(self.out_mp_.graph.input) +
- list(self.out_mp_.graph.initializer)
+ i.name
+ for i in list(self.out_mp_.graph.input) + list(
+ self.out_mp_.graph.initializer)
])
if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]):
# Loop/Scan will have some graph output in graph inputs, so don't do topological sort