diff --git a/demos/audio_searching/src/operations/load.py b/demos/audio_searching/src/operations/load.py index 0d9edb78..d1ea0057 100644 --- a/demos/audio_searching/src/operations/load.py +++ b/demos/audio_searching/src/operations/load.py @@ -26,8 +26,9 @@ def get_audios(path): """ supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] return [ - item for sublist in [[os.path.join(dir, file) for file in files] - for dir, _, files in list(os.walk(path))] + item + for sublist in [[os.path.join(dir, file) for file in files] + for dir, _, files in list(os.walk(path))] for item in sublist if os.path.splitext(item)[1] in supported_formats ] diff --git a/demos/speech_web/API.md b/demos/speech_web/API.md index c5144674..f66ec138 100644 --- a/demos/speech_web/API.md +++ b/demos/speech_web/API.md @@ -401,4 +401,4 @@ curl -X 'GET' \ "code": 0, "result":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", "message": "ok" -``` \ No newline at end of file +``` diff --git a/demos/speech_web/speech_server/main.py b/demos/speech_web/speech_server/main.py index b1017667..d4750d59 100644 --- a/demos/speech_web/speech_server/main.py +++ b/demos/speech_web/speech_server/main.py @@ -3,48 +3,48 @@ # 2. 接收录音音频,返回识别结果 # 3. 接收ASR识别结果,返回NLP对话结果 # 4. 接收NLP对话结果,返回TTS音频 - +import argparse import base64 -import yaml -import os -import json import datetime +import json +import os +from typing import List + +import aiofiles import librosa import soundfile as sf -import numpy as np -import argparse import uvicorn -import aiofiles -from typing import Optional, List -from pydantic import BaseModel -from fastapi import FastAPI, Header, File, UploadFile, Form, Cookie, WebSocket, WebSocketDisconnect +from fastapi import FastAPI +from fastapi import File +from fastapi import Form +from fastapi import UploadFile +from fastapi import WebSocket +from fastapi import WebSocketDisconnect from fastapi.responses import StreamingResponse -from starlette.responses import FileResponse -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.websockets import WebSocketState as WebSocketState - +from pydantic import BaseModel from src.AudioManeger import AudioMannger -from src.util import * from src.robot import Robot -from src.WebsocketManeger import ConnectionManager from src.SpeechBase.vpr import VPR +from src.util import * +from src.WebsocketManeger import ConnectionManager +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse +from starlette.websockets import WebSocketState as WebSocketState from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.utils.audio_process import float2pcm - # 解析配置 -parser = argparse.ArgumentParser( - prog='PaddleSpeechDemo', add_help=True) +parser = argparse.ArgumentParser(prog='PaddleSpeechDemo', add_help=True) parser.add_argument( - "--port", - action="store", - type=int, - help="port of the app", - default=8010, - required=False) + "--port", + action="store", + type=int, + help="port of the app", + default=8010, + required=False) args = parser.parse_args() port = args.port @@ -60,39 +60,41 @@ ie_model_path = "source/model" UPLOAD_PATH = "source/vpr" WAV_PATH = "source/wav" - -base_sources = [ - UPLOAD_PATH, WAV_PATH -] +base_sources = [UPLOAD_PATH, WAV_PATH] for path in base_sources: os.makedirs(path, exist_ok=True) - # 初始化 app = FastAPI() -chatbot = Robot(asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path) +chatbot = Robot( + asr_config, tts_config, asr_init_path, ie_model_path=ie_model_path) manager = ConnectionManager() aumanager = AudioMannger(chatbot) aumanager.init() -vpr = VPR(db_path, dim = 192, top_k = 5) +vpr = VPR(db_path, dim=192, top_k=5) + # 服务配置 class NlpBase(BaseModel): chat: str + class TtsBase(BaseModel): - text: str + text: str + class Audios: def __init__(self) -> None: self.audios = b"" + audios = Audios() ###################################################################### ########################### ASR 服务 ################################# ##################################################################### + # 接收文件,返回ASR结果 # 上传文件 @app.post("/asr/offline") @@ -101,7 +103,8 @@ async def speech2textOffline(files: List[UploadFile]): asr_res = "" for file in files[:1]: # 生成时间戳 - now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "asr_offline_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) async with aiofiles.open(out_file_path, 'wb') as out_file: content = await file.read() # async read @@ -110,10 +113,9 @@ async def speech2textOffline(files: List[UploadFile]): # 返回ASR识别结果 asr_res = chatbot.speech2text(out_file_path) return SuccessRequest(result=asr_res) - # else: - # return ErrorRequest(message="文件不是.wav格式") return ErrorRequest(message="上传文件为空") + # 接收文件,同时将wav强制转成16k, int16类型 @app.post("/asr/offlinefile") async def speech2textOfflineFile(files: List[UploadFile]): @@ -121,7 +123,8 @@ async def speech2textOfflineFile(files: List[UploadFile]): asr_res = "" for file in files[:1]: # 生成时间戳 - now_name = "asr_offline_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "asr_offline_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) async with aiofiles.open(out_file_path, 'wb') as out_file: content = await file.read() # async read @@ -132,22 +135,18 @@ async def speech2textOfflineFile(files: List[UploadFile]): wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') - + # 将文件重新写入 now_name = now_name[:-4] + "_16k" + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) - sf.write(out_file_path,wav,16000) + sf.write(out_file_path, wav, 16000) # 返回ASR识别结果 asr_res = chatbot.speech2text(out_file_path) - response_res = { - "asr_result": asr_res, - "wav_base64": wav_base64 - } + response_res = {"asr_result": asr_res, "wav_base64": wav_base64} return SuccessRequest(result=response_res) - - return ErrorRequest(message="上传文件为空") + return ErrorRequest(message="上传文件为空") # 流式接收测试 @@ -161,15 +160,17 @@ async def speech2textOnlineRecive(files: List[UploadFile]): print(f"audios长度变化: {len(audios.audios)}") return SuccessRequest(message="接收成功") + # 采集环境噪音大小 @app.post("/asr/collectEnv") async def collectEnv(files: List[UploadFile]): - for file in files[:1]: + for file in files[:1]: content = await file.read() # async read # 初始化, wav 前44字节是头部信息 aumanager.compute_env_volume(content[44:]) vad_ = aumanager.vad_threshold - return SuccessRequest(result=vad_,message="采集环境噪音成功") + return SuccessRequest(result=vad_, message="采集环境噪音成功") + # 停止录音 @app.get("/asr/stopRecord") @@ -179,6 +180,7 @@ async def stopRecord(): print("Online录音暂停") return SuccessRequest(message="停止成功") + # 恢复录音 @app.get("/asr/resumeRecord") async def resumeRecord(): @@ -187,7 +189,7 @@ async def resumeRecord(): return SuccessRequest(message="Online录音恢复") -# 聊天用的ASR +# 聊天用的 ASR @app.websocket("/ws/asr/offlineStream") async def websocket_endpoint(websocket: WebSocket): await manager.connect(websocket) @@ -210,9 +212,9 @@ async def websocket_endpoint(websocket: WebSocket): # print(f"用户-{user}-离开") -# Online识别的ASR + # 流式识别的 ASR @app.websocket('/ws/asr/onlineStream') -async def websocket_endpoint(websocket: WebSocket): +async def websocket_endpoint_online(websocket: WebSocket): """PaddleSpeech Online ASR Server api Args: @@ -298,12 +300,14 @@ async def websocket_endpoint(websocket: WebSocket): except WebSocketDisconnect: pass + ###################################################################### ########################### NLP 服务 ################################# ##################################################################### + @app.post("/nlp/chat") -async def chatOffline(nlp_base:NlpBase): +async def chatOffline(nlp_base: NlpBase): chat = nlp_base.chat if not chat: return ErrorRequest(message="传入文本为空") @@ -311,8 +315,9 @@ async def chatOffline(nlp_base:NlpBase): res = chatbot.chat(chat) return SuccessRequest(result=res) + @app.post("/nlp/ie") -async def ieOffline(nlp_base:NlpBase): +async def ieOffline(nlp_base: NlpBase): nlp_text = nlp_base.chat if not nlp_text: return ErrorRequest(message="传入文本为空") @@ -320,17 +325,20 @@ async def ieOffline(nlp_base:NlpBase): res = chatbot.ie(nlp_text) return SuccessRequest(result=res) + ###################################################################### ########################### TTS 服务 ################################# ##################################################################### + @app.post("/tts/offline") -async def text2speechOffline(tts_base:TtsBase): +async def text2speechOffline(tts_base: TtsBase): text = tts_base.text if not text: return ErrorRequest(message="文本为空") else: - now_name = "tts_"+ datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + now_name = "tts_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" out_file_path = os.path.join(WAV_PATH, now_name) # 保存为文件,再转成base64传输 chatbot.text2speech(text, outpath=out_file_path) @@ -339,12 +347,14 @@ async def text2speechOffline(tts_base:TtsBase): base_str = base64.b64encode(data_bin) return SuccessRequest(result=base_str) + # http流式TTS @app.post("/tts/online") async def stream_tts(request_body: TtsBase): text = request_body.text return StreamingResponse(chatbot.text2speechStreamBytes(text=text)) + # ws流式TTS @app.websocket("/ws/tts/online") async def stream_ttsWS(websocket: WebSocket): @@ -356,17 +366,11 @@ async def stream_ttsWS(websocket: WebSocket): if text: for sub_wav in chatbot.text2speechStream(text=text): # print("发送sub wav: ", len(sub_wav)) - res = { - "wav": sub_wav, - "done": False - } + res = {"wav": sub_wav, "done": False} await websocket.send_json(res) - + # 输送结束 - res = { - "wav": sub_wav, - "done": True - } + res = {"wav": sub_wav, "done": True} await websocket.send_json(res) # manager.disconnect(websocket) @@ -396,8 +400,9 @@ async def vpr_enroll(table_name: str=None, return {'status': False, 'msg': "spk_id can not be None"} # Save the upload data to server. content = await audio.read() - now_name = "vpr_enroll_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" - audio_path = os.path.join(UPLOAD_PATH, now_name) + now_name = "vpr_enroll_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + audio_path = os.path.join(UPLOAD_PATH, now_name) with open(audio_path, "wb+") as f: f.write(content) @@ -413,20 +418,19 @@ async def vpr_recog(request: Request, audio: UploadFile=File(...)): # Voice print recognition online # try: - # Save the upload data to server. + # Save the upload data to server. content = await audio.read() - now_name = "vpr_query_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" - query_audio_path = os.path.join(UPLOAD_PATH, now_name) + now_name = "vpr_query_" + datetime.datetime.strftime( + datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav" + query_audio_path = os.path.join(UPLOAD_PATH, now_name) with open(query_audio_path, "wb+") as f: - f.write(content) + f.write(content) spk_ids, paths, scores = vpr.do_search_vpr(query_audio_path) res = dict(zip(spk_ids, zip(paths, scores))) # Sort results by distance metric, closest distances first res = sorted(res.items(), key=lambda item: item[1][1], reverse=True) return res - # except Exception as e: - # return {'status': False, 'msg': e}, 400 @app.post('/vpr/del') @@ -460,17 +464,18 @@ async def vpr_database64(vprId: int): return {'status': False, 'msg': "vpr_id can not be None"} audio_path = vpr.do_get_wav(vprId) # 返回base64 - + # 将文件转成16k, 16bit类型的wav文件 wav, sr = librosa.load(audio_path, sr=16000) wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') - + return SuccessRequest(result=wav_base64) except Exception as e: return {'status': False, 'msg': e}, 400 + @app.get('/vpr/data') async def vpr_data(vprId: int): # Get the audio file from path by spk_id in MySQL @@ -482,11 +487,6 @@ async def vpr_data(vprId: int): except Exception as e: return {'status': False, 'msg': e}, 400 + if __name__ == '__main__': uvicorn.run(app=app, host='0.0.0.0', port=port) - - - - - - diff --git a/demos/speech_web/speech_server/requirements.txt b/demos/speech_web/speech_server/requirements.txt index 7e7bd168..607f0d4d 100644 --- a/demos/speech_web/speech_server/requirements.txt +++ b/demos/speech_web/speech_server/requirements.txt @@ -1,14 +1,13 @@ aiofiles +faiss-cpu fastapi librosa numpy +paddlenlp +paddlepaddle +paddlespeech pydantic -scikit_learn +python-multipartscikit_learn SoundFile starlette uvicorn -paddlepaddle -paddlespeech -paddlenlp -faiss-cpu -python-multipart \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/AudioManeger.py b/demos/speech_web/speech_server/src/AudioManeger.py index 0deb0369..8fe512cf 100644 --- a/demos/speech_web/speech_server/src/AudioManeger.py +++ b/demos/speech_web/speech_server/src/AudioManeger.py @@ -1,15 +1,19 @@ -import imp -from queue import Queue -import numpy as np +import datetime import os import wave -import random -import datetime + +import numpy as np + from .util import randName class AudioMannger: - def __init__(self, robot, frame_length=160, frame=10, data_width=2, vad_default = 300): + def __init__(self, + robot, + frame_length=160, + frame=10, + data_width=2, + vad_default=300): # 二进制 pcm 流 self.audios = b'' self.asr_result = "" @@ -20,8 +24,9 @@ class AudioMannger: os.makedirs(self.file_dir, exist_ok=True) self.vad_deafult = vad_default self.vad_threshold = vad_default - self.vad_threshold_path = os.path.join(self.file_dir, "vad_threshold.npy") - + self.vad_threshold_path = os.path.join(self.file_dir, + "vad_threshold.npy") + # 10ms 一帧 self.frame_length = frame_length # 10帧,检测一次 vad @@ -30,67 +35,64 @@ class AudioMannger: self.data_width = data_width # window self.window_length = frame_length * frame * data_width - + # 是否开始录音 self.on_asr = False - self.silence_cnt = 0 + self.silence_cnt = 0 self.max_silence_cnt = 4 self.is_pause = False # 录音暂停与恢复 - - - + def init(self): if os.path.exists(self.vad_threshold_path): # 平均响度文件存在 self.vad_threshold = np.load(self.vad_threshold_path) - - + def clear_audio(self): # 清空 pcm 累积片段与 asr 识别结果 self.audios = b'' - + def clear_asr(self): self.asr_result = "" - - + def compute_chunk_volume(self, start_index, pcm_bins): # 根据帧长计算能量平均值 - pcm_bin = pcm_bins[start_index: start_index + self.window_length] + pcm_bin = pcm_bins[start_index:start_index + self.window_length] # 转成 numpy pcm_np = np.frombuffer(pcm_bin, np.int16) # 归一化 + 计算响度 x = pcm_np.astype(np.float32) x = np.abs(x) - return np.mean(x) - - + return np.mean(x) + def is_speech(self, start_index, pcm_bins): # 检查是否没 if start_index > len(pcm_bins): return False # 检查从这个 start 开始是否为静音帧 - energy = self.compute_chunk_volume(start_index=start_index, pcm_bins=pcm_bins) + energy = self.compute_chunk_volume( + start_index=start_index, pcm_bins=pcm_bins) # print(energy) if energy > self.vad_threshold: return True else: return False - + def compute_env_volume(self, pcm_bins): max_energy = 0 start = 0 while start < len(pcm_bins): - energy = self.compute_chunk_volume(start_index=start, pcm_bins=pcm_bins) + energy = self.compute_chunk_volume( + start_index=start, pcm_bins=pcm_bins) if energy > max_energy: max_energy = energy start += self.window_length self.vad_threshold = max_energy + 100 if max_energy > self.vad_deafult else self.vad_deafult - + # 保存成文件 np.save(self.vad_threshold_path, self.vad_threshold) print(f"vad 阈值大小: {self.vad_threshold}") print(f"环境采样保存: {os.path.realpath(self.vad_threshold_path)}") - + def stream_asr(self, pcm_bin): # 先把 pcm_bin 送进去做端点检测 start = 0 @@ -99,7 +101,7 @@ class AudioMannger: self.on_asr = True self.silence_cnt = 0 print("录音中") - self.audios += pcm_bin[ start : start + self.window_length] + self.audios += pcm_bin[start:start + self.window_length] else: if self.on_asr: self.silence_cnt += 1 @@ -110,41 +112,42 @@ class AudioMannger: print("录音停止") # audios 保存为 wav, 送入 ASR if len(self.audios) > 2 * 16000: - file_path = os.path.join(self.file_dir, "asr_" + datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d%H%M%S') + randName() + ".wav") + file_path = os.path.join( + self.file_dir, + "asr_" + datetime.datetime.strftime( + datetime.datetime.now(), + '%Y%m%d%H%M%S') + randName() + ".wav") self.save_audio(file_path=file_path) self.asr_result = self.robot.speech2text(file_path) self.clear_audio() - return self.asr_result + return self.asr_result else: # 正常接收 print("录音中 静音") - self.audios += pcm_bin[ start : start + self.window_length] + self.audios += pcm_bin[start:start + self.window_length] start += self.window_length return "" - + def save_audio(self, file_path): print("保存音频") - wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav" - wf.setnchannels(1) # 设置声道数为2 - wf.setsampwidth(2) # 设置采样深度为 - wf.setframerate(16000) # 设置采样率为16000 + wf = wave.open(file_path, 'wb') # 创建一个音频文件,名字为“01.wav" + wf.setnchannels(1) # 设置声道数为2 + wf.setsampwidth(2) # 设置采样深度为 + wf.setframerate(16000) # 设置采样率为16000 # 将数据写入创建的音频文件 wf.writeframes(self.audios) # 写完后将文件关闭 wf.close() - + def end(self): # audios 保存为 wav, 送入 ASR file_path = os.path.join(self.file_dir, "asr.wav") self.save_audio(file_path=file_path) return self.robot.speech2text(file_path) - + def stop(self): self.is_pause = True self.audios = b'' - + def resume(self): self.is_pause = False - - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/asr.py b/demos/speech_web/speech_server/src/SpeechBase/asr.py index 8d4c0cff..5213ea78 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/asr.py +++ b/demos/speech_web/speech_server/src/SpeechBase/asr.py @@ -1,13 +1,10 @@ -from re import sub import numpy as np -import paddle -import librosa -import soundfile from paddlespeech.server.engine.asr.online.python.asr_engine import ASREngine from paddlespeech.server.engine.asr.online.python.asr_engine import PaddleASRConnectionHanddler from paddlespeech.server.utils.config import get_config + def readWave(samples): x_len = len(samples) @@ -31,20 +28,23 @@ def readWave(samples): class ASR: - def __init__(self, config_path, ) -> None: + def __init__( + self, + config_path, ) -> None: self.config = get_config(config_path)['asr_online'] self.engine = ASREngine() self.engine.init(self.config) self.connection_handler = PaddleASRConnectionHanddler(self.engine) - + def offlineASR(self, samples, sample_rate=16000): - x_chunk, x_chunk_lens = self.engine.preprocess(samples=samples, sample_rate=sample_rate) + x_chunk, x_chunk_lens = self.engine.preprocess( + samples=samples, sample_rate=sample_rate) self.engine.run(x_chunk, x_chunk_lens) result = self.engine.postprocess() self.engine.reset() return result - def onlineASR(self, samples:bytes=None, is_finished=False): + def onlineASR(self, samples: bytes=None, is_finished=False): if not is_finished: # 流式开始 self.connection_handler.extract_feat(samples) @@ -58,5 +58,3 @@ class ASR: asr_results = self.connection_handler.get_result() self.connection_handler.reset() return asr_results - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/nlp.py b/demos/speech_web/speech_server/src/SpeechBase/nlp.py index 4ece6325..b642a51d 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/nlp.py +++ b/demos/speech_web/speech_server/src/SpeechBase/nlp.py @@ -1,23 +1,23 @@ from paddlenlp import Taskflow + class NLP: def __init__(self, ie_model_path=None): schema = ["时间", "出发地", "目的地", "费用"] if ie_model_path: - self.ie_model = Taskflow("information_extraction", - schema=schema, task_path=ie_model_path) + self.ie_model = Taskflow( + "information_extraction", + schema=schema, + task_path=ie_model_path) else: - self.ie_model = Taskflow("information_extraction", - schema=schema) - + self.ie_model = Taskflow("information_extraction", schema=schema) + self.dialogue_model = Taskflow("dialogue") - + def chat(self, text): result = self.dialogue_model([text]) return result[0] - + def ie(self, text): result = self.ie_model(text) return result - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py index 6937def5..bd8d5897 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py +++ b/demos/speech_web/speech_server/src/SpeechBase/sql_helper.py @@ -1,18 +1,19 @@ import base64 -import sqlite3 import os +import sqlite3 + import numpy as np -from pkg_resources import resource_stream -def dict_factory(cursor, row): - d = {} - for idx, col in enumerate(cursor.description): - d[col[0]] = row[idx] - return d +def dict_factory(cursor, row): + d = {} + for idx, col in enumerate(cursor.description): + d[col[0]] = row[idx] + return d + class DataBase(object): - def __init__(self, db_path:str): + def __init__(self, db_path: str): db_path = os.path.realpath(db_path) if os.path.exists(db_path): @@ -21,12 +22,12 @@ class DataBase(object): db_path_dir = os.path.dirname(db_path) os.makedirs(db_path_dir, exist_ok=True) self.db_path = db_path - + self.conn = sqlite3.connect(self.db_path) self.conn.row_factory = dict_factory self.cursor = self.conn.cursor() self.init_database() - + def init_database(self): """ 初始化数据库, 若表不存在则创建 @@ -41,20 +42,21 @@ class DataBase(object): """ self.cursor.execute(sql) self.conn.commit() - + def execute_base(self, sql, data_dict): self.cursor.execute(sql, data_dict) self.conn.commit() - - def insert_one(self, username, vector_base64:str, wav_path): + + def insert_one(self, username, vector_base64: str, wav_path): if not os.path.exists(wav_path): return None, "wav not exists" else: - sql = f""" + sql = """ insert into vprtable (username, vector, wavpath) values (?, ?, ?) """ + try: self.cursor.execute(sql, (username, vector_base64, wav_path)) self.conn.commit() @@ -63,25 +65,27 @@ class DataBase(object): except Exception as e: print(e) return None, e - + def select_all(self): sql = """ SELECT * from vprtable """ result = self.cursor.execute(sql).fetchall() return result - + def select_by_id(self, vpr_id): sql = f""" SELECT * from vprtable WHERE `id` = {vpr_id} """ + result = self.cursor.execute(sql).fetchall() return result - + def select_by_username(self, username): sql = f""" SELECT * from vprtable WHERE `username` = '{username}' """ + result = self.cursor.execute(sql).fetchall() return result @@ -89,28 +93,30 @@ class DataBase(object): sql = f""" DELETE from vprtable WHERE `username`='{username}' """ + self.cursor.execute(sql) self.conn.commit() - + def drop_all(self): - sql = f""" + sql = """ DELETE from vprtable """ + self.cursor.execute(sql) self.conn.commit() - + def drop_table(self): - sql = f""" + sql = """ DROP TABLE vprtable """ + self.cursor.execute(sql) self.conn.commit() - - def encode_vector(self, vector:np.ndarray): + + def encode_vector(self, vector: np.ndarray): return base64.b64encode(vector).decode('utf8') - + def decode_vector(self, vector_base64, dtype=np.float32): b = base64.b64decode(vector_base64) vc = np.frombuffer(b, dtype=dtype) return vc - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/tts.py b/demos/speech_web/speech_server/src/SpeechBase/tts.py index d5ba0c80..eb32bca0 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/tts.py +++ b/demos/speech_web/speech_server/src/SpeechBase/tts.py @@ -5,18 +5,19 @@ # 2. 加载模型 # 3. 端到端推理 # 4. 流式推理 - import base64 -import math import logging +import math + import numpy as np -from paddlespeech.server.utils.onnx_infer import get_sess -from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.server.utils.util import denorm, get_chunks + +from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine from paddlespeech.server.utils.audio_process import float2pcm from paddlespeech.server.utils.config import get_config +from paddlespeech.server.utils.util import denorm +from paddlespeech.server.utils.util import get_chunks +from paddlespeech.t2s.frontend.zh_frontend import Frontend -from paddlespeech.server.engine.tts.online.onnx.tts_engine import TTSEngine class TTS: def __init__(self, config_path): @@ -26,12 +27,12 @@ class TTS: self.engine.init(self.config) self.executor = self.engine.executor #self.engine.warm_up() - + # 前端初始化 self.frontend = Frontend( - phone_vocab_path=self.engine.executor.phones_dict, - tone_vocab_path=None) - + phone_vocab_path=self.engine.executor.phones_dict, + tone_vocab_path=None) + def depadding(self, data, chunk_num, chunk_id, block, pad, upsample): """ Streaming inference removes the result of pad inference @@ -48,39 +49,37 @@ class TTS: data = data[front_pad * upsample:(front_pad + block) * upsample] return data - + def offlineTTS(self, text): get_tone_ids = False merge_sentences = False - + input_ids = self.frontend.get_input_ids( - text, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) + text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] wav_list = [] for i in range(len(phone_ids)): orig_hs = self.engine.executor.am_encoder_infer_sess.run( - None, input_feed={'text': phone_ids[i].numpy()} - ) + None, input_feed={'text': phone_ids[i].numpy()}) hs = orig_hs[0] am_decoder_output = self.engine.executor.am_decoder_sess.run( - None, input_feed={'xs': hs}) + None, input_feed={'xs': hs}) am_postnet_output = self.engine.executor.am_postnet_sess.run( - None, - input_feed={ - 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) - }) + None, + input_feed={ + 'xs': np.transpose(am_decoder_output[0], (0, 2, 1)) + }) am_output_data = am_decoder_output + np.transpose( am_postnet_output[0], (0, 2, 1)) normalized_mel = am_output_data[0][0] - mel = denorm(normalized_mel, self.engine.executor.am_mu, self.engine.executor.am_std) + mel = denorm(normalized_mel, self.engine.executor.am_mu, + self.engine.executor.am_std) wav = self.engine.executor.voc_sess.run( - output_names=None, input_feed={'logmel': mel})[0] + output_names=None, input_feed={'logmel': mel})[0] wav_list.append(wav) wavs = np.concatenate(wav_list) return wavs - + def streamTTS(self, text): get_tone_ids = False @@ -88,9 +87,7 @@ class TTS: # front input_ids = self.frontend.get_input_ids( - text, - merge_sentences=merge_sentences, - get_tone_ids=get_tone_ids) + text, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] for i in range(len(phone_ids)): @@ -105,14 +102,15 @@ class TTS: mel = mel[0] # voc streaming - mel_chunks = get_chunks(mel, self.config.voc_block, self.config.voc_pad, "voc") + mel_chunks = get_chunks(mel, self.config.voc_block, + self.config.voc_pad, "voc") voc_chunk_num = len(mel_chunks) for i, mel_chunk in enumerate(mel_chunks): sub_wav = self.executor.voc_sess.run( output_names=None, input_feed={'logmel': mel_chunk}) - sub_wav = self.depadding(sub_wav[0], voc_chunk_num, i, - self.config.voc_block, self.config.voc_pad, - self.config.voc_upsample) + sub_wav = self.depadding( + sub_wav[0], voc_chunk_num, i, self.config.voc_block, + self.config.voc_pad, self.config.voc_upsample) yield self.after_process(sub_wav) @@ -130,7 +128,8 @@ class TTS: end = min(self.config.voc_block + self.config.voc_pad, mel_len) # streaming am - hss = get_chunks(orig_hs, self.config.am_block, self.config.am_pad, "am") + hss = get_chunks(orig_hs, self.config.am_block, + self.config.am_pad, "am") am_chunk_num = len(hss) for i, hs in enumerate(hss): am_decoder_output = self.executor.am_decoder_sess.run( @@ -147,7 +146,8 @@ class TTS: sub_mel = denorm(normalized_mel, self.executor.am_mu, self.executor.am_std) sub_mel = self.depadding(sub_mel, am_chunk_num, i, - self.config.am_block, self.config.am_pad, 1) + self.config.am_block, + self.config.am_pad, 1) if i == 0: mel_streaming = sub_mel @@ -165,23 +165,22 @@ class TTS: output_names=None, input_feed={'logmel': voc_chunk}) sub_wav = self.depadding( sub_wav[0], voc_chunk_num, voc_chunk_id, - self.config.voc_block, self.config.voc_pad, self.config.voc_upsample) + self.config.voc_block, self.config.voc_pad, + self.config.voc_upsample) yield self.after_process(sub_wav) voc_chunk_id += 1 - start = max( - 0, voc_chunk_id * self.config.voc_block - self.config.voc_pad) - end = min( - (voc_chunk_id + 1) * self.config.voc_block + self.config.voc_pad, - mel_len) + start = max(0, voc_chunk_id * self.config.voc_block - + self.config.voc_pad) + end = min((voc_chunk_id + 1) * self.config.voc_block + + self.config.voc_pad, mel_len) else: logging.error( "Only support fastspeech2_csmsc or fastspeech2_cnndecoder_csmsc on streaming tts." - ) + ) - def streamTTSBytes(self, text): for wav in self.engine.executor.infer( text=text, @@ -191,19 +190,14 @@ class TTS: wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes yield wav_bytes - - + def after_process(self, wav): # for tvm wav = float2pcm(wav) # float32 to int16 wav_bytes = wav.tobytes() # to bytes wav_base64 = base64.b64encode(wav_bytes).decode('utf8') # to base64 return wav_base64 - + def streamTTS_TVM(self, text): # 用 TVM 优化 pass - - - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr.py b/demos/speech_web/speech_server/src/SpeechBase/vpr.py index 29ee986e..cf336799 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/vpr.py +++ b/demos/speech_web/speech_server/src/SpeechBase/vpr.py @@ -1,11 +1,13 @@ # vpr Demo 没有使用 mysql 与 muilvs, 仅用于docker演示 import logging + import faiss -from matplotlib import use import numpy as np + from .sql_helper import DataBase from .vpr_encode import get_audio_embedding + class VPR: def __init__(self, db_path, dim, top_k) -> None: # 初始化 @@ -14,15 +16,15 @@ class VPR: self.top_k = top_k self.dtype = np.float32 self.vpr_idx = 0 - + # db 初始化 self.db = DataBase(db_path) - + # faiss 初始化 index_ip = faiss.IndexFlatIP(dim) self.index_ip = faiss.IndexIDMap(index_ip) self.init() - + def init(self): # demo 初始化,把 mysql中的向量注册到 faiss 中 sql_dbs = self.db.select_all() @@ -34,12 +36,13 @@ class VPR: if len(vc.shape) == 1: vc = np.expand_dims(vc, axis=0) # 构建数据库 - self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64')) + self.index_ip.add_with_ids(vc, np.array( + (idx, )).astype('int64')) logging.info("faiss 构建完毕") - + def faiss_enroll(self, idx, vc): - self.index_ip.add_with_ids(vc, np.array((idx,)).astype('int64')) - + self.index_ip.add_with_ids(vc, np.array((idx, )).astype('int64')) + def vpr_enroll(self, username, wav_path): # 注册声纹 emb = get_audio_embedding(wav_path) @@ -53,21 +56,22 @@ class VPR: else: last_idx, mess = None return last_idx - + def vpr_recog(self, wav_path): # 识别声纹 emb_search = get_audio_embedding(wav_path) - + if emb_search is not None: emb_search = np.expand_dims(emb_search, axis=0) D, I = self.index_ip.search(emb_search, self.top_k) D = D.tolist()[0] - I = I.tolist()[0] - return [(round(D[i] * 100, 2 ), I[i]) for i in range(len(D)) if I[i] != -1] + I = I.tolist()[0] + return [(round(D[i] * 100, 2), I[i]) for i in range(len(D)) + if I[i] != -1] else: logging.error("识别失败") return None - + def do_search_vpr(self, wav_path): spk_ids, paths, scores = [], [], [] recog_result = self.vpr_recog(wav_path) @@ -78,41 +82,39 @@ class VPR: scores.append(score) paths.append("") return spk_ids, paths, scores - + def vpr_del(self, username): # 根据用户username, 删除声纹 # 查用户ID,删除对应向量 res = self.db.select_by_username(username) for r in res: idx = r['id'] - self.index_ip.remove_ids(np.array((idx,)).astype('int64')) - + self.index_ip.remove_ids(np.array((idx, )).astype('int64')) + self.db.drop_by_username(username) - + def vpr_list(self): # 获取数据列表 return self.db.select_all() - + def do_list(self): spk_ids, vpr_ids = [], [] for res in self.db.select_all(): spk_ids.append(res['username']) vpr_ids.append(res['id']) - return spk_ids, vpr_ids - + return spk_ids, vpr_ids + def do_get_wav(self, vpr_idx): - res = self.db.select_by_id(vpr_idx) - return res[0]['wavpath'] - - + res = self.db.select_by_id(vpr_idx) + return res[0]['wavpath'] + def vpr_data(self, idx): # 获取对应ID的数据 res = self.db.select_by_id(idx) return res - + def vpr_droptable(self): # 删除表 self.db.drop_table() # 清空 faiss self.index_ip.reset() - diff --git a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py index a6a00e4d..9d052fd9 100644 --- a/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py +++ b/demos/speech_web/speech_server/src/SpeechBase/vpr_encode.py @@ -1,9 +1,12 @@ -from paddlespeech.cli.vector import VectorExecutor -import numpy as np import logging +import numpy as np + +from paddlespeech.cli.vector import VectorExecutor + vector_executor = VectorExecutor() + def get_audio_embedding(path): """ Use vpr_inference to generate embedding of audio @@ -16,5 +19,3 @@ def get_audio_embedding(path): except Exception as e: logging.error(f"Error with embedding:{e}") return None - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/WebsocketManeger.py b/demos/speech_web/speech_server/src/WebsocketManeger.py index 5edde843..954d849a 100644 --- a/demos/speech_web/speech_server/src/WebsocketManeger.py +++ b/demos/speech_web/speech_server/src/WebsocketManeger.py @@ -2,6 +2,7 @@ from typing import List from fastapi import WebSocket + class ConnectionManager: def __init__(self): # 存放激活的ws连接对象 @@ -28,4 +29,4 @@ class ConnectionManager: await connection.send_text(message) -manager = ConnectionManager() \ No newline at end of file +manager = ConnectionManager() diff --git a/demos/speech_web/speech_server/src/robot.py b/demos/speech_web/speech_server/src/robot.py index b971c57b..dd8c56e0 100644 --- a/demos/speech_web/speech_server/src/robot.py +++ b/demos/speech_web/speech_server/src/robot.py @@ -1,60 +1,64 @@ -from paddlespeech.cli.asr.infer import ASRExecutor -import soundfile as sf import os -import librosa +import soundfile as sf from src.SpeechBase.asr import ASR -from src.SpeechBase.tts import TTS from src.SpeechBase.nlp import NLP +from src.SpeechBase.tts import TTS + +from paddlespeech.cli.asr.infer import ASRExecutor class Robot: - def __init__(self, asr_config, tts_config,asr_init_path, + def __init__(self, + asr_config, + tts_config, + asr_init_path, ie_model_path=None) -> None: self.nlp = NLP(ie_model_path=ie_model_path) self.asr = ASR(config_path=asr_config) self.tts = TTS(config_path=tts_config) self.tts_sample_rate = 24000 self.asr_sample_rate = 16000 - + # 流式识别效果不如端到端的模型,这里流式模型与端到端模型分开 self.asr_model = ASRExecutor() self.asr_name = "conformer_wenetspeech" self.warm_up_asrmodel(asr_init_path) - - def warm_up_asrmodel(self, asr_init_path): + def warm_up_asrmodel(self, asr_init_path): if not os.path.exists(asr_init_path): path_dir = os.path.dirname(asr_init_path) if not os.path.exists(path_dir): os.makedirs(path_dir, exist_ok=True) - + # TTS生成,采样率24000 text = "生成初始音频" self.text2speech(text, asr_init_path) - + # asr model初始化 - self.asr_model(asr_init_path, model=self.asr_name,lang='zh', - sample_rate=16000, force_yes=True) - - + self.asr_model( + asr_init_path, + model=self.asr_name, + lang='zh', + sample_rate=16000, + force_yes=True) + def speech2text(self, audio_file): self.asr_model.preprocess(self.asr_name, audio_file) self.asr_model.infer(self.asr_name) res = self.asr_model.postprocess() return res - + def text2speech(self, text, outpath): wav = self.tts.offlineTTS(text) - sf.write( - outpath, wav, samplerate=self.tts_sample_rate) + sf.write(outpath, wav, samplerate=self.tts_sample_rate) res = wav return res - + def text2speechStream(self, text): for sub_wav_base64 in self.tts.streamTTS(text=text): yield sub_wav_base64 - + def text2speechStreamBytes(self, text): for wav_bytes in self.tts.streamTTSBytes(text=text): yield wav_bytes @@ -66,5 +70,3 @@ class Robot: def ie(self, text): result = self.nlp.ie(text) return result - - \ No newline at end of file diff --git a/demos/speech_web/speech_server/src/util.py b/demos/speech_web/speech_server/src/util.py index 34005d91..4a566b6e 100644 --- a/demos/speech_web/speech_server/src/util.py +++ b/demos/speech_web/speech_server/src/util.py @@ -1,18 +1,13 @@ import random + def randName(n=5): - return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba',n)) + return "".join(random.sample('zyxwvutsrqponmlkjihgfedcba', n)) + def SuccessRequest(result=None, message="ok"): - return { - "code": 0, - "result":result, - "message": message - } + return {"code": 0, "result": result, "message": message} + def ErrorRequest(result=None, message="error"): - return { - "code": -1, - "result":result, - "message": message - } \ No newline at end of file + return {"code": -1, "result": result, "message": message} diff --git a/demos/streaming_asr_server/local/rtf_from_log.py b/demos/streaming_asr_server/local/rtf_from_log.py index 4b89b48f..09a9c975 100755 --- a/demos/streaming_asr_server/local/rtf_from_log.py +++ b/demos/streaming_asr_server/local/rtf_from_log.py @@ -34,7 +34,7 @@ if __name__ == '__main__': n = 0 for m in rtfs: # not accurate, may have duplicate log - n += 1 + n += 1 T += m['T'] P += m['P'] diff --git a/docs/requirements.txt b/docs/requirements.txt index ee116a9b..11e94f48 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,12 +1,6 @@ -myst-parser -numpydoc -recommonmark>=0.5.0 -sphinx -sphinx-autobuild -sphinx-markdown-tables -sphinx_rtd_theme -paddlepaddle>=2.2.2 +braceexpandcolorlog editdistance +fastapi g2p_en g2pM h5py @@ -14,40 +8,45 @@ inflect jieba jsonlines kaldiio +keyboard librosa==0.8.1 loguru matplotlib +myst-parser nara_wpe +numpydoc onnxruntime==1.10.0 opencc -pandas paddlenlp +paddlepaddle>=2.2.2 paddlespeech_feat +pandas +pathos == 0.2.8 +pattern_singleton Pillow>=9.0.0 praatio==5.0.0 +prettytable pypinyin pypinyin-dict python-dateutil pyworld==0.2.12 +recommonmark>=0.5.0 resampy==0.2.2 sacrebleu scipy sentencepiece~=0.1.96 soundfile~=0.10 +sphinx +sphinx-autobuild +sphinx-markdown-tables +sphinx_rtd_theme textgrid timer tqdm typeguard +uvicorn visualdl webrtcvad +websockets yacs~=0.1.8 -prettytable zhon -colorlog -pathos == 0.2.8 -fastapi -websockets -keyboard -uvicorn -pattern_singleton -braceexpand \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index c94cf0b8..cd9b1807 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,10 +20,11 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import os +import sys + import recommonmark.parser import sphinx_rtd_theme -import sys -import os sys.path.insert(0, os.path.abspath('../..')) autodoc_mock_imports = ["soundfile", "librosa"] diff --git a/examples/iwslt2012/punc0/local/preprocess.py b/examples/iwslt2012/punc0/local/preprocess.py index 03b27e89..3df07c72 100644 --- a/examples/iwslt2012/punc0/local/preprocess.py +++ b/examples/iwslt2012/punc0/local/preprocess.py @@ -1,27 +1,29 @@ import argparse -import os + def process_sentence(line): - if line == '': return '' - res = line[0] - for i in range(1, len(line)): - res += (' ' + line[i]) - return res + if line == '': + return '' + res = line[0] + for i in range(1, len(line)): + res += (' ' + line[i]) + return res + if __name__ == "__main__": - paser = argparse.ArgumentParser(description = "Input filename") - paser.add_argument('-input_file') - paser.add_argument('-output_file') - sentence_cnt = 0 - args = paser.parse_args() - with open(args.input_file, 'r') as f: - with open(args.output_file, 'w') as write_f: - while True: - line = f.readline() - if line: - sentence_cnt += 1 - write_f.write(process_sentence(line)) - else: - break - print('preprocess over') - print('total sentences number:', sentence_cnt) + paser = argparse.ArgumentParser(description="Input filename") + paser.add_argument('-input_file') + paser.add_argument('-output_file') + sentence_cnt = 0 + args = paser.parse_args() + with open(args.input_file, 'r') as f: + with open(args.output_file, 'w') as write_f: + while True: + line = f.readline() + if line: + sentence_cnt += 1 + write_f.write(process_sentence(line)) + else: + break + print('preprocess over') + print('total sentences number:', sentence_cnt) diff --git a/examples/other/tts_finetune/tts3/finetune.py b/examples/other/tts_finetune/tts3/finetune.py index f05ba943..0f060b44 100644 --- a/examples/other/tts_finetune/tts3/finetune.py +++ b/examples/other/tts_finetune/tts3/finetune.py @@ -17,15 +17,14 @@ from pathlib import Path from typing import Union import yaml -from paddle import distributed as dist -from yacs.config import CfgNode - -from paddlespeech.t2s.exps.fastspeech2.train import train_sp - from local.check_oov import get_check_result from local.extract import extract_feature from local.label_process import get_single_label from local.prepare_env import generate_finetune_env +from paddle import distributed as dist +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.fastspeech2.train import train_sp from utils.gen_duration_from_textgrid import gen_duration_from_textgrid DICT_EN = 'tools/aligner/cmudict-0.7b' diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 4b1c0ef3..b781c4a8 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -14,5 +14,3 @@ import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) - - diff --git a/paddlespeech/audio/__init__.py b/paddlespeech/audio/__init__.py index 83be8e32..a9195810 100644 --- a/paddlespeech/audio/__init__.py +++ b/paddlespeech/audio/__init__.py @@ -14,12 +14,12 @@ from . import compliance from . import datasets from . import features -from . import text -from . import transform -from . import streamdata from . import functional from . import io from . import metric from . import sox_effects +from . import streamdata +from . import text +from . import transform from .backends import load from .backends import save diff --git a/paddlespeech/audio/streamdata/__init__.py b/paddlespeech/audio/streamdata/__init__.py index 753fcc11..47a2e79b 100644 --- a/paddlespeech/audio/streamdata/__init__.py +++ b/paddlespeech/audio/streamdata/__init__.py @@ -4,67 +4,66 @@ # Modified from https://github.com/webdataset/webdataset # # flake8: noqa - -from .cache import ( - cached_tarfile_samples, - cached_tarfile_to_samples, - lru_cleanup, - pipe_cleaner, -) -from .compat import WebDataset, WebLoader, FluidWrapper -from .extradatasets import MockDataset, with_epoch, with_length -from .filters import ( - associate, - batched, - decode, - detshuffle, - extract_keys, - getfirst, - info, - map, - map_dict, - map_tuple, - pipelinefilter, - rename, - rename_keys, - audio_resample, - select, - shuffle, - slice, - to_tuple, - transform_with, - unbatched, - xdecode, - audio_data_filter, - audio_tokenize, - audio_resample, - audio_compute_fbank, - audio_spec_aug, - sort, - audio_padding, - audio_cmvn, - placeholder, -) -from .handlers import ( - ignore_and_continue, - ignore_and_stop, - reraise_exception, - warn_and_continue, - warn_and_stop, -) +from .cache import cached_tarfile_samples +from .cache import cached_tarfile_to_samples +from .cache import lru_cleanup +from .cache import pipe_cleaner +from .compat import FluidWrapper +from .compat import WebDataset +from .compat import WebLoader +from .extradatasets import MockDataset +from .extradatasets import with_epoch +from .extradatasets import with_length +from .filters import associate +from .filters import audio_cmvn +from .filters import audio_compute_fbank +from .filters import audio_data_filter +from .filters import audio_padding +from .filters import audio_resample +from .filters import audio_spec_aug +from .filters import audio_tokenize +from .filters import batched +from .filters import decode +from .filters import detshuffle +from .filters import extract_keys +from .filters import getfirst +from .filters import info +from .filters import map +from .filters import map_dict +from .filters import map_tuple +from .filters import pipelinefilter +from .filters import placeholder +from .filters import rename +from .filters import rename_keys +from .filters import select +from .filters import shuffle +from .filters import slice +from .filters import sort +from .filters import to_tuple +from .filters import transform_with +from .filters import unbatched +from .filters import xdecode +from .handlers import ignore_and_continue +from .handlers import ignore_and_stop +from .handlers import reraise_exception +from .handlers import warn_and_continue +from .handlers import warn_and_stop +from .mix import RandomMix +from .mix import RoundRobin from .pipeline import DataPipeline -from .shardlists import ( - MultiShardSample, - ResampledShards, - SimpleShardList, - non_empty, - resampled, - shardspec, - single_node_only, - split_by_node, - split_by_worker, -) -from .tariterators import tarfile_samples, tarfile_to_samples -from .utils import PipelineStage, repeatedly -from .writer import ShardWriter, TarWriter, numpy_dumps -from .mix import RandomMix, RoundRobin +from .shardlists import MultiShardSample +from .shardlists import non_empty +from .shardlists import resampled +from .shardlists import ResampledShards +from .shardlists import shardspec +from .shardlists import SimpleShardList +from .shardlists import single_node_only +from .shardlists import split_by_node +from .shardlists import split_by_worker +from .tariterators import tarfile_samples +from .tariterators import tarfile_to_samples +from .utils import PipelineStage +from .utils import repeatedly +from .writer import numpy_dumps +from .writer import ShardWriter +from .writer import TarWriter diff --git a/paddlespeech/audio/streamdata/autodecode.py b/paddlespeech/audio/streamdata/autodecode.py index ca0e2ea2..d7f7937b 100644 --- a/paddlespeech/audio/streamdata/autodecode.py +++ b/paddlespeech/audio/streamdata/autodecode.py @@ -5,18 +5,19 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Automatically decode webdataset samples.""" - -import io, json, os, pickle, re, tempfile +import io +import json +import os +import pickle +import re +import tempfile from functools import partial import numpy as np - """Extensions passed on to the image decoder.""" image_extensions = "jpg jpeg png ppm pgm pbm pnm".split() - ################################################################ # handle basic datatypes ################################################################ @@ -128,7 +129,7 @@ def call_extension_handler(key, data, f, extensions): target = target.split(".") if len(target) > len(extension): continue - if extension[-len(target) :] == target: + if extension[-len(target):] == target: return f(data) return None @@ -268,7 +269,6 @@ def imagehandler(imagespec, extensions=image_extensions): ################################################################ # torch video ################################################################ - ''' def torch_video(key, data): """Decode video using the torchvideo library. @@ -289,7 +289,6 @@ def torch_video(key, data): return torchvision.io.read_video(fname, pts_unit="sec") ''' - ################################################################ # paddlespeech.audio ################################################################ @@ -359,7 +358,6 @@ def gzfilter(key, data): # decode entire training amples ################################################################ - default_pre_handlers = [gzfilter] default_post_handlers = [basichandlers] @@ -387,7 +385,8 @@ class Decoder: pre = default_pre_handlers if post is None: post = default_post_handlers - assert all(callable(h) for h in handlers), f"one of {handlers} not callable" + assert all(callable(h) + for h in handlers), f"one of {handlers} not callable" assert all(callable(h) for h in pre), f"one of {pre} not callable" assert all(callable(h) for h in post), f"one of {post} not callable" self.handlers = pre + handlers + post diff --git a/paddlespeech/audio/streamdata/cache.py b/paddlespeech/audio/streamdata/cache.py index e7bbffa1..faa19639 100644 --- a/paddlespeech/audio/streamdata/cache.py +++ b/paddlespeech/audio/streamdata/cache.py @@ -2,7 +2,10 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset -import itertools, os, random, re, sys +import os +import random +import re +import sys from urllib.parse import urlparse from . import filters @@ -40,7 +43,7 @@ def lru_cleanup(cache_dir, cache_size, keyfn=os.path.getctime, verbose=False): os.remove(fname) -def download(url, dest, chunk_size=1024 ** 2, verbose=False): +def download(url, dest, chunk_size=1024**2, verbose=False): """Download a file from `url` to `dest`.""" temp = dest + f".temp{os.getpid()}" with gopen.gopen(url) as stream: @@ -65,12 +68,11 @@ def pipe_cleaner(spec): def get_file_cached( - spec, - cache_size=-1, - cache_dir=None, - url_to_name=pipe_cleaner, - verbose=False, -): + spec, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + verbose=False, ): if cache_size == -1: cache_size = default_cache_size if cache_dir is None: @@ -107,15 +109,14 @@ verbose_cache = int(os.environ.get("WDS_VERBOSE_CACHE", "0")) def cached_url_opener( - data, - handler=reraise_exception, - cache_size=-1, - cache_dir=None, - url_to_name=pipe_cleaner, - validator=check_tar_format, - verbose=False, - always=False, -): + data, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + url_to_name=pipe_cleaner, + validator=check_tar_format, + verbose=False, + always=False, ): """Given a stream of url names (packaged in `dict(url=url)`), yield opened streams.""" verbose = verbose or verbose_cache for sample in data: @@ -132,8 +133,7 @@ def cached_url_opener( cache_size=cache_size, cache_dir=cache_dir, url_to_name=url_to_name, - verbose=verbose, - ) + verbose=verbose, ) if verbose: print("# opening %s" % dest, file=sys.stderr) assert os.path.exists(dest) @@ -143,9 +143,8 @@ def cached_url_opener( data = f.read(200) os.remove(dest) raise ValueError( - "%s (%s) is not a tar archive, but a %s, contains %s" - % (dest, url, ftype, repr(data)) - ) + "%s (%s) is not a tar archive, but a %s, contains %s" % + (dest, url, ftype, repr(data))) try: stream = open(dest, "rb") sample.update(stream=stream) @@ -158,7 +157,7 @@ def cached_url_opener( continue raise exn except Exception as exn: - exn.args = exn.args + (url,) + exn.args = exn.args + (url, ) if handler(exn): continue else: @@ -166,14 +165,13 @@ def cached_url_opener( def cached_tarfile_samples( - src, - handler=reraise_exception, - cache_size=-1, - cache_dir=None, - verbose=False, - url_to_name=pipe_cleaner, - always=False, -): + src, + handler=reraise_exception, + cache_size=-1, + cache_dir=None, + verbose=False, + url_to_name=pipe_cleaner, + always=False, ): streams = cached_url_opener( src, handler=handler, @@ -181,8 +179,7 @@ def cached_tarfile_samples( cache_dir=cache_dir, verbose=verbose, url_to_name=url_to_name, - always=always, - ) + always=always, ) samples = tar_file_and_group_expander(streams, handler=handler) return samples diff --git a/paddlespeech/audio/streamdata/compat.py b/paddlespeech/audio/streamdata/compat.py index deda5338..9012eeb1 100644 --- a/paddlespeech/audio/streamdata/compat.py +++ b/paddlespeech/audio/streamdata/compat.py @@ -2,17 +2,17 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset -from dataclasses import dataclass -from itertools import islice -from typing import List - -import braceexpand, yaml +import yaml from . import autodecode -from . import cache, filters, shardlists, tariterators +from . import cache +from . import filters +from . import shardlists +from . import tariterators from .filters import reraise_exception +from .paddle_utils import DataLoader +from .paddle_utils import IterableDataset from .pipeline import DataPipeline -from .paddle_utils import DataLoader, IterableDataset class FluidInterface: @@ -26,7 +26,8 @@ class FluidInterface: return self.compose(filters.unbatched()) def listed(self, batchsize, partial=True): - return self.compose(filters.batched(), batchsize=batchsize, collation_fn=None) + return self.compose( + filters.batched(), batchsize=batchsize, collation_fn=None) def unlisted(self): return self.compose(filters.unlisted()) @@ -43,9 +44,19 @@ class FluidInterface: def map(self, f, handler=reraise_exception): return self.compose(filters.map(f, handler=handler)) - def decode(self, *args, pre=None, post=None, only=None, partial=False, handler=reraise_exception): - handlers = [autodecode.ImageHandler(x) if isinstance(x, str) else x for x in args] - decoder = autodecode.Decoder(handlers, pre=pre, post=post, only=only, partial=partial) + def decode(self, + *args, + pre=None, + post=None, + only=None, + partial=False, + handler=reraise_exception): + handlers = [ + autodecode.ImageHandler(x) if isinstance(x, str) else x + for x in args + ] + decoder = autodecode.Decoder( + handlers, pre=pre, post=post, only=only, partial=partial) return self.map(decoder, handler=handler) def map_dict(self, handler=reraise_exception, **kw): @@ -80,12 +91,12 @@ class FluidInterface: def audio_data_filter(self, *args, **kw): return self.compose(filters.audio_data_filter(*args, **kw)) - + def audio_tokenize(self, *args, **kw): return self.compose(filters.audio_tokenize(*args, **kw)) def resample(self, *args, **kw): - return self.compose(filters.resample(*args, **kw)) + return self.compose(filters.resample(*args, **kw)) def audio_compute_fbank(self, *args, **kw): return self.compose(filters.audio_compute_fbank(*args, **kw)) @@ -102,27 +113,28 @@ class FluidInterface: def audio_cmvn(self, cmvn_file): return self.compose(filters.audio_cmvn(cmvn_file)) + class WebDataset(DataPipeline, FluidInterface): """Small fluid-interface wrapper for DataPipeline.""" def __init__( - self, - urls, - handler=reraise_exception, - resampled=False, - repeat=False, - shardshuffle=None, - cache_size=0, - cache_dir=None, - detshuffle=False, - nodesplitter=shardlists.single_node_only, - verbose=False, - ): + self, + urls, + handler=reraise_exception, + resampled=False, + repeat=False, + shardshuffle=None, + cache_size=0, + cache_dir=None, + detshuffle=False, + nodesplitter=shardlists.single_node_only, + verbose=False, ): super().__init__() if isinstance(urls, IterableDataset): assert not resampled self.append(urls) - elif isinstance(urls, str) and (urls.endswith(".yaml") or urls.endswith(".yml")): + elif isinstance(urls, str) and (urls.endswith(".yaml") or + urls.endswith(".yml")): with (open(urls)) as stream: spec = yaml.safe_load(stream) assert "datasets" in spec @@ -152,9 +164,7 @@ class WebDataset(DataPipeline, FluidInterface): handler=handler, verbose=verbose, cache_size=cache_size, - cache_dir=cache_dir, - ) - ) + cache_dir=cache_dir, )) class FluidWrapper(DataPipeline, FluidInterface): diff --git a/paddlespeech/audio/streamdata/extradatasets.py b/paddlespeech/audio/streamdata/extradatasets.py index e6d61772..76361c24 100644 --- a/paddlespeech/audio/streamdata/extradatasets.py +++ b/paddlespeech/audio/streamdata/extradatasets.py @@ -5,20 +5,10 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - - """Train PyTorch models directly from POSIX tar archive. Code works locally or over HTTP connections. """ - -import itertools as itt -import os -import random -import sys - -import braceexpand - from . import utils from .paddle_utils import IterableDataset from .utils import PipelineStage @@ -63,8 +53,7 @@ class repeatedly(IterableDataset, PipelineStage): return utils.repeatedly( source, nepochs=self.nepochs, - nbatches=self.nbatches, - ) + nbatches=self.nbatches, ) class with_epoch(IterableDataset): diff --git a/paddlespeech/audio/streamdata/filters.py b/paddlespeech/audio/streamdata/filters.py index 82b9c6ba..68d6830b 100644 --- a/paddlespeech/audio/streamdata/filters.py +++ b/paddlespeech/audio/streamdata/filters.py @@ -3,7 +3,6 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset # Modified from wenet(https://github.com/wenet-e2e/wenet) """A collection of iterators for data transformations. @@ -12,28 +11,29 @@ These functions are plain iterator functions. You can find curried versions in webdataset.filters, and you can find IterableDataset wrappers in webdataset.processing. """ - import io -from fnmatch import fnmatch +import itertools +import os +import random import re -import itertools, os, random, sys, time -from functools import reduce, wraps +import sys +import time +from fnmatch import fnmatch +from functools import reduce -import numpy as np +import paddle from . import autodecode -from . import utils -from .paddle_utils import PaddleTensor -from .utils import PipelineStage - +from . import utils from .. import backends from ..compliance import kaldi -import paddle from ..transform.cmvn import GlobalCMVN -from ..utils.tensor_utils import pad_sequence -from ..transform.spec_augment import time_warp -from ..transform.spec_augment import time_mask from ..transform.spec_augment import freq_mask +from ..transform.spec_augment import time_mask +from ..transform.spec_augment import time_warp +from ..utils.tensor_utils import pad_sequence +from .utils import PipelineStage + class FilterFunction(object): """Helper class for currying pipeline stages. @@ -159,10 +159,12 @@ def transform_with(sample, transformers): result[i] = f(sample[i]) return result + ### # Iterators ### + def _info(data, fmt=None, n=3, every=-1, width=50, stream=sys.stderr, name=""): """Print information about the samples that are passing through. @@ -278,10 +280,16 @@ def _log_keys(data, logfile=None): log_keys = pipelinefilter(_log_keys) +def _minedecode(x): + if isinstance(x, str): + return autodecode.imagehandler(x) + else: + return x + + def _decode(data, *args, handler=reraise_exception, **kw): """Decode data based on the decoding functions given as arguments.""" - - decoder = lambda x: autodecode.imagehandler(x) if isinstance(x, str) else x + decoder = _minedecode handlers = [decoder(x) for x in args] f = autodecode.Decoder(handlers, **kw) @@ -325,15 +333,24 @@ def _rename(data, handler=reraise_exception, keep=True, **kw): for sample in data: try: if not keep: - yield {k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()} + yield { + k: getfirst(sample, v, missing_is_error=True) + for k, v in kw.items() + } else: def listify(v): return v.split(";") if isinstance(v, str) else v to_be_replaced = {x for v in kw.values() for x in listify(v)} - result = {k: v for k, v in sample.items() if k not in to_be_replaced} - result.update({k: getfirst(sample, v, missing_is_error=True) for k, v in kw.items()}) + result = { + k: v + for k, v in sample.items() if k not in to_be_replaced + } + result.update({ + k: getfirst(sample, v, missing_is_error=True) + for k, v in kw.items() + }) yield result except Exception as exn: if handler(exn): @@ -381,7 +398,11 @@ def _map_dict(data, handler=reraise_exception, **kw): map_dict = pipelinefilter(_map_dict) -def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, none_is_error=None): +def _to_tuple(data, + *args, + handler=reraise_exception, + missing_is_error=True, + none_is_error=None): """Convert dict samples to tuples.""" if none_is_error is None: none_is_error = missing_is_error @@ -390,7 +411,10 @@ def _to_tuple(data, *args, handler=reraise_exception, missing_is_error=True, non for sample in data: try: - result = tuple([getfirst(sample, f, missing_is_error=missing_is_error) for f in args]) + result = tuple([ + getfirst(sample, f, missing_is_error=missing_is_error) + for f in args + ]) if none_is_error and any(x is None for x in result): raise ValueError(f"to_tuple {args} got {sample.keys()}") yield result @@ -463,19 +487,28 @@ rsample = pipelinefilter(_rsample) slice = pipelinefilter(itertools.islice) -def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=False): +def _extract_keys(source, + *patterns, + duplicate_is_error=True, + ignore_missing=False): for sample in source: result = [] for pattern in patterns: - pattern = pattern.split(";") if isinstance(pattern, str) else pattern - matches = [x for x in sample.keys() if any(fnmatch("." + x, p) for p in pattern)] + pattern = pattern.split(";") if isinstance(pattern, + str) else pattern + matches = [ + x for x in sample.keys() + if any(fnmatch("." + x, p) for p in pattern) + ] if len(matches) == 0: if ignore_missing: continue else: - raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + raise ValueError( + f"Cannot find {pattern} in sample keys {sample.keys()}.") if len(matches) > 1 and duplicate_is_error: - raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + raise ValueError( + f"Multiple sample keys {sample.keys()} match {pattern}.") value = sample[matches[0]] result.append(value) yield tuple(result) @@ -484,7 +517,12 @@ def _extract_keys(source, *patterns, duplicate_is_error=True, ignore_missing=Fal extract_keys = pipelinefilter(_extract_keys) -def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicate_is_error=True, **kw): +def _rename_keys(source, + *args, + keep_unselected=False, + must_match=True, + duplicate_is_error=True, + **kw): renamings = [(pattern, output) for output, pattern in args] renamings += [(pattern, output) for output, pattern in kw.items()] for sample in source: @@ -504,11 +542,15 @@ def _rename_keys(source, *args, keep_unselected=False, must_match=True, duplicat continue if new_name in new_sample: if duplicate_is_error: - raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") + raise ValueError( + f"Duplicate value in sample {sample.keys()} after rename." + ) continue new_sample[new_name] = value if must_match and not all(matched.values()): - raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") + raise ValueError( + f"Not all patterns ({matched}) matched sample keys ({sample.keys()})." + ) yield new_sample @@ -541,18 +583,18 @@ def find_decoder(decoders, path): if fname.startswith("__"): return lambda x: x for pattern, fun in decoders[::-1]: - if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), pattern): + if fnmatch(fname.lower(), pattern) or fnmatch("." + fname.lower(), + pattern): return fun return None def _xdecode( - source, - *args, - must_decode=True, - defaults=default_decoders, - **kw, -): + source, + *args, + must_decode=True, + defaults=default_decoders, + **kw, ): decoders = list(defaults) + list(args) decoders += [("*." + k, v) for k, v in kw.items()] for sample in source: @@ -575,18 +617,18 @@ def _xdecode( new_sample[path] = value yield new_sample -xdecode = pipelinefilter(_xdecode) +xdecode = pipelinefilter(_xdecode) def _audio_data_filter(source, - frame_shift=10, - max_length=10240, - min_length=10, - token_max_length=200, - token_min_length=1, - min_output_input_ratio=0.0005, - max_output_input_ratio=1): + frame_shift=10, + max_length=10240, + min_length=10, + token_max_length=200, + token_min_length=1, + min_output_input_ratio=0.0005, + max_output_input_ratio=1): """ Filter sample according to feature and label length Inplace operation. @@ -613,7 +655,8 @@ def _audio_data_filter(source, assert 'wav' in sample assert 'label' in sample # sample['wav'] is paddle.Tensor, we have 100 frames every second (default) - num_frames = sample['wav'].shape[1] / sample['sample_rate'] * (1000 / frame_shift) + num_frames = sample['wav'].shape[1] / sample['sample_rate'] * ( + 1000 / frame_shift) if num_frames < min_length: continue if num_frames > max_length: @@ -629,13 +672,15 @@ def _audio_data_filter(source, continue yield sample + audio_data_filter = pipelinefilter(_audio_data_filter) + def _audio_tokenize(source, - symbol_table, - bpe_model=None, - non_lang_syms=None, - split_with_space=False): + symbol_table, + bpe_model=None, + non_lang_syms=None, + split_with_space=False): """ Decode text to chars or BPE Inplace operation @@ -693,8 +738,10 @@ def _audio_tokenize(source, sample['label'] = label yield sample + audio_tokenize = pipelinefilter(_audio_tokenize) + def _audio_resample(source, resample_rate=16000): """ Resample data. Inplace operation. @@ -713,18 +760,22 @@ def _audio_resample(source, resample_rate=16000): waveform = sample['wav'] if sample_rate != resample_rate: sample['sample_rate'] = resample_rate - sample['wav'] = paddle.to_tensor(backends.soundfile_backend.resample( - waveform.numpy(), src_sr = sample_rate, target_sr = resample_rate - )) + sample['wav'] = paddle.to_tensor( + backends.soundfile_backend.resample( + waveform.numpy(), + src_sr=sample_rate, + target_sr=resample_rate)) yield sample + audio_resample = pipelinefilter(_audio_resample) + def _audio_compute_fbank(source, - num_mel_bins=80, - frame_length=25, - frame_shift=10, - dither=0.0): + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0): """ Extract fbank Args: @@ -746,30 +797,33 @@ def _audio_compute_fbank(source, waveform = sample['wav'] waveform = waveform * (1 << 15) # Only keep fname, feat, label - mat = kaldi.fbank(waveform, - n_mels=num_mel_bins, - frame_length=frame_length, - frame_shift=frame_shift, - dither=dither, - energy_floor=0.0, - sr=sample_rate) + mat = kaldi.fbank( + waveform, + n_mels=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither, + energy_floor=0.0, + sr=sample_rate) yield dict(fname=sample['fname'], label=sample['label'], feat=mat) audio_compute_fbank = pipelinefilter(_audio_compute_fbank) -def _audio_spec_aug(source, - max_w=5, - w_inplace=True, - w_mode="PIL", - max_f=30, - num_f_mask=2, - f_inplace=True, - f_replace_with_zero=False, - max_t=40, - num_t_mask=2, - t_inplace=True, - t_replace_with_zero=False,): + +def _audio_spec_aug( + source, + max_w=5, + w_inplace=True, + w_mode="PIL", + max_f=30, + num_f_mask=2, + f_inplace=True, + f_replace_with_zero=False, + max_t=40, + num_t_mask=2, + t_inplace=True, + t_replace_with_zero=False, ): """ Do spec augmentation Inplace operation @@ -793,12 +847,23 @@ def _audio_spec_aug(source, for sample in source: x = sample['feat'] x = x.numpy() - x = time_warp(x, max_time_warp=max_w, inplace = w_inplace, mode= w_mode) - x = freq_mask(x, F = max_f, n_mask = num_f_mask, inplace = f_inplace, replace_with_zero = f_replace_with_zero) - x = time_mask(x, T = max_t, n_mask = num_t_mask, inplace = t_inplace, replace_with_zero = t_replace_with_zero) + x = time_warp(x, max_time_warp=max_w, inplace=w_inplace, mode=w_mode) + x = freq_mask( + x, + F=max_f, + n_mask=num_f_mask, + inplace=f_inplace, + replace_with_zero=f_replace_with_zero) + x = time_mask( + x, + T=max_t, + n_mask=num_t_mask, + inplace=t_inplace, + replace_with_zero=t_replace_with_zero) sample['feat'] = paddle.to_tensor(x, dtype=paddle.float32) yield sample + audio_spec_aug = pipelinefilter(_audio_spec_aug) @@ -829,8 +894,10 @@ def _sort(source, sort_size=500): for x in buf: yield x + sort = pipelinefilter(_sort) + def _batched(source, batch_size=16): """ Static batch the data by `batch_size` @@ -850,8 +917,10 @@ def _batched(source, batch_size=16): if len(buf) > 0: yield buf + batched = pipelinefilter(_batched) + def dynamic_batched(source, max_frames_in_batch=12000): """ Dynamic batch the data until the total frames in batch reach `max_frames_in_batch` @@ -892,8 +961,8 @@ def _audio_padding(source): """ for sample in source: assert isinstance(sample, list) - feats_length = paddle.to_tensor([x['feat'].shape[0] for x in sample], - dtype="int64") + feats_length = paddle.to_tensor( + [x['feat'].shape[0] for x in sample], dtype="int64") order = paddle.argsort(feats_length, descending=True) feats_lengths = paddle.to_tensor( [sample[i]['feat'].shape[0] for i in order], dtype="int64") @@ -902,20 +971,20 @@ def _audio_padding(source): sorted_labels = [ paddle.to_tensor(sample[i]['label'], dtype="int32") for i in order ] - label_lengths = paddle.to_tensor([x.shape[0] for x in sorted_labels], - dtype="int64") - padded_feats = pad_sequence(sorted_feats, - batch_first=True, - padding_value=0) - padding_labels = pad_sequence(sorted_labels, - batch_first=True, - padding_value=-1) - - yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths = paddle.to_tensor( + [x.shape[0] for x in sorted_labels], dtype="int64") + padded_feats = pad_sequence( + sorted_feats, batch_first=True, padding_value=0) + padding_labels = pad_sequence( + sorted_labels, batch_first=True, padding_value=-1) + + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, label_lengths) + audio_padding = pipelinefilter(_audio_padding) + def _audio_cmvn(source, cmvn_file): global_cmvn = GlobalCMVN(cmvn_file) for batch in source: @@ -923,13 +992,16 @@ def _audio_cmvn(source, cmvn_file): padded_feats = padded_feats.numpy() padded_feats = global_cmvn(padded_feats) padded_feats = paddle.to_tensor(padded_feats, dtype=paddle.float32) - yield (sorted_keys, padded_feats, feats_lengths, padding_labels, - label_lengths) + yield (sorted_keys, padded_feats, feats_lengths, padding_labels, + label_lengths) + audio_cmvn = pipelinefilter(_audio_cmvn) + def _placeholder(source): for data in source: yield data + placeholder = pipelinefilter(_placeholder) diff --git a/paddlespeech/audio/streamdata/gopen.py b/paddlespeech/audio/streamdata/gopen.py index 457d048a..60a43460 100644 --- a/paddlespeech/audio/streamdata/gopen.py +++ b/paddlespeech/audio/streamdata/gopen.py @@ -3,12 +3,12 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - - """Open URLs by calling subcommands.""" - -import os, sys, re -from subprocess import PIPE, Popen +import os +import re +import sys +from subprocess import PIPE +from subprocess import Popen from urllib.parse import urlparse # global used for printing additional node information during verbose output @@ -31,14 +31,13 @@ class Pipe: """ def __init__( - self, - *args, - mode=None, - timeout=7200.0, - ignore_errors=False, - ignore_status=[], - **kw, - ): + self, + *args, + mode=None, + timeout=7200.0, + ignore_errors=False, + ignore_status=[], + **kw, ): """Create an IO Pipe.""" self.ignore_errors = ignore_errors self.ignore_status = [0] + ignore_status @@ -75,8 +74,7 @@ class Pipe: if verbose: print( f"pipe exit [{self.status} {os.getpid()}:{self.proc.pid}] {self.args} {info}", - file=sys.stderr, - ) + file=sys.stderr, ) if self.status not in self.ignore_status and not self.ignore_errors: raise Exception(f"{self.args}: exit {self.status} (read) {info}") @@ -114,9 +112,11 @@ class Pipe: self.close() -def set_options( - obj, timeout=None, ignore_errors=None, ignore_status=None, handler=None -): +def set_options(obj, + timeout=None, + ignore_errors=None, + ignore_status=None, + handler=None): """Set options for Pipes. This function can be called on any stream. It will set pipe options only @@ -168,16 +168,14 @@ def gopen_pipe(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141], - ) # skipcq: BAN-B604 + ignore_status=[141], ) # skipcq: BAN-B604 elif mode[0] == "w": return Pipe( cmd, mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141], - ) # skipcq: BAN-B604 + ignore_status=[141], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") @@ -196,8 +194,7 @@ def gopen_curl(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": cmd = f"curl -s -L -T - '{url}'" return Pipe( @@ -205,8 +202,7 @@ def gopen_curl(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 26], - ) # skipcq: BAN-B604 + ignore_status=[141, 26], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") @@ -226,15 +222,13 @@ def gopen_htgs(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": raise ValueError(f"{mode}: cannot write") else: raise ValueError(f"{mode}: unknown mode") - def gopen_gsutil(url, mode="rb", bufsize=8192): """Open a URL with `curl`. @@ -249,8 +243,7 @@ def gopen_gsutil(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 23], - ) # skipcq: BAN-B604 + ignore_status=[141, 23], ) # skipcq: BAN-B604 elif mode[0] == "w": cmd = f"gsutil cp - '{url}'" return Pipe( @@ -258,13 +251,11 @@ def gopen_gsutil(url, mode="rb", bufsize=8192): mode=mode, shell=True, bufsize=bufsize, - ignore_status=[141, 26], - ) # skipcq: BAN-B604 + ignore_status=[141, 26], ) # skipcq: BAN-B604 else: raise ValueError(f"{mode}: unknown mode") - def gopen_error(url, *args, **kw): """Raise a value error. @@ -285,8 +276,7 @@ gopen_schemes = dict( ftps=gopen_curl, scp=gopen_curl, gs=gopen_gsutil, - htgs=gopen_htgs, -) + htgs=gopen_htgs, ) def gopen(url, mode="rb", bufsize=8192, **kw): diff --git a/paddlespeech/audio/streamdata/handlers.py b/paddlespeech/audio/streamdata/handlers.py index 7f3d28b6..0173e537 100644 --- a/paddlespeech/audio/streamdata/handlers.py +++ b/paddlespeech/audio/streamdata/handlers.py @@ -3,7 +3,6 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - """Pluggable exception handlers. These are functions that take an exception as an argument and then return... @@ -14,8 +13,8 @@ These are functions that take an exception as an argument and then return... They are used as handler= arguments in much of the library. """ - -import time, warnings +import time +import warnings def reraise_exception(exn): diff --git a/paddlespeech/audio/streamdata/mix.py b/paddlespeech/audio/streamdata/mix.py index 7d790f00..37556ed9 100644 --- a/paddlespeech/audio/streamdata/mix.py +++ b/paddlespeech/audio/streamdata/mix.py @@ -5,17 +5,12 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Classes for mixing samples from multiple sources.""" - -import itertools, os, random, time, sys -from functools import reduce, wraps +import random import numpy as np -from . import autodecode, utils -from .paddle_utils import PaddleTensor, IterableDataset -from .utils import PipelineStage +from .paddle_utils import IterableDataset def round_robin_shortest(*sources): diff --git a/paddlespeech/audio/streamdata/paddle_utils.py b/paddlespeech/audio/streamdata/paddle_utils.py index 02bc4c84..c2ad8756 100644 --- a/paddlespeech/audio/streamdata/paddle_utils.py +++ b/paddlespeech/audio/streamdata/paddle_utils.py @@ -5,12 +5,11 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Mock implementations of paddle interfaces when paddle is not available.""" - try: - from paddle.io import DataLoader, IterableDataset + from paddle.io import DataLoader + from paddle.io import IterableDataset except ModuleNotFoundError: class IterableDataset: @@ -22,12 +21,3 @@ except ModuleNotFoundError: """Empty implementation of DataLoader when paddle is not available.""" pass - -try: - from paddle import Tensor as PaddleTensor -except ModuleNotFoundError: - - class TorchTensor: - """Empty implementation of PaddleTensor when paddle is not available.""" - - pass diff --git a/paddlespeech/audio/streamdata/pipeline.py b/paddlespeech/audio/streamdata/pipeline.py index 7339a762..ff16760a 100644 --- a/paddlespeech/audio/streamdata/pipeline.py +++ b/paddlespeech/audio/streamdata/pipeline.py @@ -3,15 +3,12 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset #%% -import copy, os, random, sys, time -from dataclasses import dataclass +import copy +import sys from itertools import islice -from typing import List -import braceexpand, yaml - -from .handlers import reraise_exception -from .paddle_utils import DataLoader, IterableDataset +from .paddle_utils import DataLoader +from .paddle_utils import IterableDataset from .utils import PipelineStage @@ -22,8 +19,7 @@ def add_length_method(obj): Combined = type( obj.__class__.__name__ + "_Length", (obj.__class__, IterableDataset), - {"__len__": length}, - ) + {"__len__": length}, ) obj.__class__ = Combined return obj diff --git a/paddlespeech/audio/streamdata/shardlists.py b/paddlespeech/audio/streamdata/shardlists.py index cfaf9a64..54f50105 100644 --- a/paddlespeech/audio/streamdata/shardlists.py +++ b/paddlespeech/audio/streamdata/shardlists.py @@ -4,28 +4,30 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset - """Train PyTorch models directly from POSIX tar archive. Code works locally or over HTTP connections. """ - -import os, random, sys, time -from dataclasses import dataclass, field +import os +import random +import sys +import time +from dataclasses import dataclass +from dataclasses import field from itertools import islice from typing import List -import braceexpand, yaml +import braceexpand +import yaml from . import utils +from ..utils.log import Logger from .filters import pipelinefilter from .paddle_utils import IterableDataset +logger = Logger(__name__) -from ..utils.log import Logger -logger = Logger(__name__) def expand_urls(urls): if isinstance(urls, str): urllist = urls.split("::") @@ -64,7 +66,8 @@ class SimpleShardList(IterableDataset): def split_by_node(src, group=None): - rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + rank, world_size, worker, num_workers = utils.paddle_worker_info( + group=group) logger.info(f"world_size:{world_size}, rank:{rank}") if world_size > 1: for s in islice(src, rank, None, world_size): @@ -75,9 +78,11 @@ def split_by_node(src, group=None): def single_node_only(src, group=None): - rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group) + rank, world_size, worker, num_workers = utils.paddle_worker_info( + group=group) if world_size > 1: - raise ValueError("input pipeline needs to be reconfigured for multinode training") + raise ValueError( + "input pipeline needs to be reconfigured for multinode training") for s in src: yield s @@ -104,7 +109,8 @@ def resampled_(src, n=sys.maxsize): rng = random.Random(seed) print("# resampled loading", file=sys.stderr) items = list(src) - print(f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) + print( + f"# resampled got {len(items)} samples, yielding {n}", file=sys.stderr) for i in range(n): yield rng.choice(items) @@ -118,7 +124,9 @@ def non_empty(src): yield s count += 1 if count == 0: - raise ValueError("pipeline stage received no data at all and this was declared as an error") + raise ValueError( + "pipeline stage received no data at all and this was declared as an error" + ) @dataclass @@ -138,10 +146,6 @@ def expand(s): return os.path.expanduser(os.path.expandvars(s)) -class MultiShardSample(IterableDataset): - def __init__(self, fname): - """Construct a shardlist from multiple sources using a YAML spec.""" - self.epoch = -1 class MultiShardSample(IterableDataset): def __init__(self, fname): """Construct a shardlist from multiple sources using a YAML spec.""" @@ -156,20 +160,23 @@ class MultiShardSample(IterableDataset): else: with open(fname) as stream: spec = yaml.safe_load(stream) - assert set(spec.keys()).issubset(set("prefix datasets buckets".split())), list(spec.keys()) + assert set(spec.keys()).issubset( + set("prefix datasets buckets".split())), list(spec.keys()) prefix = expand(spec.get("prefix", "")) self.sources = [] for ds in spec["datasets"]: - assert set(ds.keys()).issubset(set("buckets name shards resample choose".split())), list( - ds.keys() - ) + assert set(ds.keys()).issubset( + set("buckets name shards resample choose".split())), list( + ds.keys()) buckets = ds.get("buckets", spec.get("buckets", [])) if isinstance(buckets, str): buckets = [buckets] buckets = [expand(s) for s in buckets] if buckets == []: buckets = [""] - assert len(buckets) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" + assert len( + buckets + ) == 1, f"{buckets}: FIXME support for multiple buckets unimplemented" bucket = buckets[0] name = ds.get("name", "@" + bucket) urls = ds["shards"] @@ -177,15 +184,19 @@ class MultiShardSample(IterableDataset): urls = [urls] # urls = [u for url in urls for u in braceexpand.braceexpand(url)] urls = [ - prefix + os.path.join(bucket, u) for url in urls for u in braceexpand.braceexpand(expand(url)) + prefix + os.path.join(bucket, u) + for url in urls for u in braceexpand.braceexpand(expand(url)) ] resample = ds.get("resample", -1) nsample = ds.get("choose", -1) if nsample > len(urls): - raise ValueError(f"perepoch {nsample} must be no greater than the number of shards") + raise ValueError( + f"perepoch {nsample} must be no greater than the number of shards" + ) if (nsample > 0) and (resample > 0): raise ValueError("specify only one of perepoch or choose") - entry = MSSource(name=name, urls=urls, perepoch=nsample, resample=resample) + entry = MSSource( + name=name, urls=urls, perepoch=nsample, resample=resample) self.sources.append(entry) print(f"# {name} {len(urls)} {nsample}", file=sys.stderr) @@ -203,7 +214,7 @@ class MultiShardSample(IterableDataset): # sample without replacement l = list(source.urls) self.rng.shuffle(l) - l = l[: source.perepoch] + l = l[:source.perepoch] else: l = list(source.urls) result += l @@ -227,12 +238,11 @@ class ResampledShards(IterableDataset): """An iterable dataset yielding a list of urls.""" def __init__( - self, - urls, - nshards=sys.maxsize, - worker_seed=None, - deterministic=False, - ): + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=False, ): """Sample shards from the shard list with replacement. :param urls: a list of URLs as a Python list or brace notation string @@ -252,7 +262,8 @@ class ResampledShards(IterableDataset): if self.deterministic: seed = utils.make_seed(self.worker_seed(), self.epoch) else: - seed = utils.make_seed(self.worker_seed(), self.epoch, os.getpid(), time.time_ns(), os.urandom(4)) + seed = utils.make_seed(self.worker_seed(), self.epoch, + os.getpid(), time.time_ns(), os.urandom(4)) if os.environ.get("WDS_SHOW_SEED", "0") == "1": print(f"# ResampledShards seed {seed}") self.rng = random.Random(seed) diff --git a/paddlespeech/audio/streamdata/tariterators.py b/paddlespeech/audio/streamdata/tariterators.py index b1616918..79b81c0c 100644 --- a/paddlespeech/audio/streamdata/tariterators.py +++ b/paddlespeech/audio/streamdata/tariterators.py @@ -3,13 +3,12 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). - # Modified from https://github.com/webdataset/webdataset # Modified from wenet(https://github.com/wenet-e2e/wenet) - """Low level iteration functions for tar archives.""" - -import random, re, tarfile +import random +import re +import tarfile import braceexpand @@ -27,6 +26,7 @@ import numpy as np AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']) + def base_plus_ext(path): """Split off all file extensions. @@ -47,12 +47,8 @@ def valid_sample(sample): :param sample: sample to be checked """ - return ( - sample is not None - and isinstance(sample, dict) - and len(list(sample.keys())) > 0 - and not sample.get("__bad__", False) - ) + return (sample is not None and isinstance(sample, dict) and + len(list(sample.keys())) > 0 and not sample.get("__bad__", False)) # FIXME: UNUSED @@ -79,16 +75,16 @@ def url_opener(data, handler=reraise_exception, **kw): sample.update(stream=stream) yield sample except Exception as exn: - exn.args = exn.args + (url,) + exn.args = exn.args + (url, ) if handler(exn): continue else: break -def tar_file_iterator( - fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception -): +def tar_file_iterator(fileobj, + skip_meta=r"__[^/]*__($|/)", + handler=reraise_exception): """Iterate over tar file, yielding filename, content pairs for the given tar stream. :param fileobj: byte stream suitable for tarfile @@ -103,11 +99,8 @@ def tar_file_iterator( continue if fname is None: continue - if ( - "/" not in fname - and fname.startswith(meta_prefix) - and fname.endswith(meta_suffix) - ): + if ("/" not in fname and fname.startswith(meta_prefix) and + fname.endswith(meta_suffix)): # skipping metadata for now continue if skip_meta is not None and re.match(skip_meta, fname): @@ -118,8 +111,10 @@ def tar_file_iterator( assert pos > 0 prefix, postfix = name[:pos], name[pos + 1:] if postfix == 'wav': - waveform, sample_rate = paddlespeech.audio.load(stream.extractfile(tarinfo), normal=False) - result = dict(fname=prefix, wav=waveform, sample_rate = sample_rate) + waveform, sample_rate = paddlespeech.audio.load( + stream.extractfile(tarinfo), normal=False) + result = dict( + fname=prefix, wav=waveform, sample_rate=sample_rate) else: txt = stream.extractfile(tarinfo).read().decode('utf8').strip() result = dict(fname=prefix, txt=txt) @@ -128,16 +123,17 @@ def tar_file_iterator( stream.members = [] except Exception as exn: if hasattr(exn, "args") and len(exn.args) > 0: - exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + exn.args = (exn.args[0] + " @ " + str(fileobj), ) + exn.args[1:] if handler(exn): continue else: break del stream -def tar_file_and_group_iterator( - fileobj, skip_meta=r"__[^/]*__($|/)", handler=reraise_exception -): + +def tar_file_and_group_iterator(fileobj, + skip_meta=r"__[^/]*__($|/)", + handler=reraise_exception): """ Expand a stream of open tar files into a stream of tar file contents. And groups the file with same prefix @@ -167,8 +163,11 @@ def tar_file_and_group_iterator( if postfix == 'txt': example['txt'] = file_obj.read().decode('utf8').strip() elif postfix in AUDIO_FORMAT_SETS: - waveform, sample_rate = paddlespeech.audio.load(file_obj, normal=False) - waveform = paddle.to_tensor(np.expand_dims(np.array(waveform),0), dtype=paddle.float32) + waveform, sample_rate = paddlespeech.audio.load( + file_obj, normal=False) + waveform = paddle.to_tensor( + np.expand_dims(np.array(waveform), 0), + dtype=paddle.float32) example['wav'] = waveform example['sample_rate'] = sample_rate @@ -176,19 +175,21 @@ def tar_file_and_group_iterator( example[postfix] = file_obj.read() except Exception as exn: if hasattr(exn, "args") and len(exn.args) > 0: - exn.args = (exn.args[0] + " @ " + str(fileobj),) + exn.args[1:] + exn.args = (exn.args[0] + " @ " + str(fileobj), + ) + exn.args[1:] if handler(exn): continue else: break valid = False - # logging.warning('error to parse {}'.format(name)) + # logging.warning('error to parse {}'.format(name)) prev_prefix = prefix if prev_prefix is not None: example['fname'] = prev_prefix yield example stream.close() + def tar_file_expander(data, handler=reraise_exception): """Expand a stream of open tar files into a stream of tar file contents. @@ -200,9 +201,8 @@ def tar_file_expander(data, handler=reraise_exception): assert isinstance(source, dict) assert "stream" in source for sample in tar_file_iterator(source["stream"]): - assert ( - isinstance(sample, dict) and "data" in sample and "fname" in sample - ) + assert (isinstance(sample, dict) and "data" in sample and + "fname" in sample) sample["__url__"] = url yield sample except Exception as exn: @@ -213,8 +213,6 @@ def tar_file_expander(data, handler=reraise_exception): break - - def tar_file_and_group_expander(data, handler=reraise_exception): """Expand a stream of open tar files into a stream of tar file contents. @@ -226,9 +224,8 @@ def tar_file_and_group_expander(data, handler=reraise_exception): assert isinstance(source, dict) assert "stream" in source for sample in tar_file_and_group_iterator(source["stream"]): - assert ( - isinstance(sample, dict) and "wav" in sample and "txt" in sample and "fname" in sample - ) + assert (isinstance(sample, dict) and "wav" in sample and + "txt" in sample and "fname" in sample) sample["__url__"] = url yield sample except Exception as exn: @@ -239,7 +236,11 @@ def tar_file_and_group_expander(data, handler=reraise_exception): break -def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): +def group_by_keys(data, + keys=base_plus_ext, + lcase=True, + suffixes=None, + handler=None): """Return function over iterator that groups key, value pairs into samples. :param keys: function that splits the key into key and extension (base_plus_ext) @@ -254,8 +255,8 @@ def group_by_keys(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=N print( prefix, suffix, - current_sample.keys() if isinstance(current_sample, dict) else None, - ) + current_sample.keys() + if isinstance(current_sample, dict) else None, ) if prefix is None: continue if lcase: diff --git a/paddlespeech/audio/streamdata/utils.py b/paddlespeech/audio/streamdata/utils.py index c7294f2b..94dab905 100644 --- a/paddlespeech/audio/streamdata/utils.py +++ b/paddlespeech/audio/streamdata/utils.py @@ -4,22 +4,23 @@ # This file is part of the WebDataset library. # See the LICENSE file for licensing terms (BSD-style). # - # Modified from https://github.com/webdataset/webdataset - """Miscellaneous utility functions.""" - import importlib import itertools as itt import os import re import sys -from typing import Any, Callable, Iterator, Optional, Union +from typing import Any +from typing import Callable +from typing import Iterator +from typing import Union from ..utils.log import Logger logger = Logger(__name__) + def make_seed(*args): seed = 0 for arg in args: @@ -37,7 +38,7 @@ def identity(x: Any) -> Any: return x -def safe_eval(s: str, expr: str = "{}"): +def safe_eval(s: str, expr: str="{}"): """Evaluate the given expression more safely.""" if re.sub("[^A-Za-z0-9_]", "", s) != s: raise ValueError(f"safe_eval: illegal characters in: '{s}'") @@ -54,9 +55,9 @@ def lookup_sym(sym: str, modules: list): return None -def repeatedly0( - loader: Iterator, nepochs: int = sys.maxsize, nbatches: int = sys.maxsize -): +def repeatedly0(loader: Iterator, + nepochs: int=sys.maxsize, + nbatches: int=sys.maxsize): """Repeatedly returns batches from a DataLoader.""" for epoch in range(nepochs): for sample in itt.islice(loader, nbatches): @@ -69,12 +70,11 @@ def guess_batchsize(batch: Union[tuple, list]): def repeatedly( - source: Iterator, - nepochs: int = None, - nbatches: int = None, - nsamples: int = None, - batchsize: Callable[..., int] = guess_batchsize, -): + source: Iterator, + nepochs: int=None, + nbatches: int=None, + nsamples: int=None, + batchsize: Callable[..., int]=guess_batchsize, ): """Repeatedly yield samples from an iterator.""" epoch = 0 batch = 0 @@ -93,6 +93,7 @@ def repeatedly( if nepochs is not None and epoch >= nepochs: return + def paddle_worker_info(group=None): """Return node and worker info for PyTorch and some distributed environments.""" rank = 0 @@ -116,7 +117,7 @@ def paddle_worker_info(group=None): else: try: from paddle.io import get_worker_info - worker_info = paddle.io.get_worker_info() + worker_info = get_worker_info() if worker_info is not None: worker = worker_info.id num_workers = worker_info.num_workers @@ -126,6 +127,7 @@ def paddle_worker_info(group=None): return rank, world_size, worker, num_workers + def paddle_worker_seed(group=None): """Compute a distinct, deterministic RNG seed for each worker and node.""" rank, world_size, worker, num_workers = paddle_worker_info(group=group) diff --git a/paddlespeech/audio/streamdata/writer.py b/paddlespeech/audio/streamdata/writer.py index 7d4f7703..3928a3ba 100644 --- a/paddlespeech/audio/streamdata/writer.py +++ b/paddlespeech/audio/streamdata/writer.py @@ -5,18 +5,24 @@ # See the LICENSE file for licensing terms (BSD-style). # Modified from https://github.com/webdataset/webdataset # - """Classes and functions for writing tar files and WebDataset files.""" - -import io, json, pickle, re, tarfile, time -from typing import Any, Callable, Optional, Union +import io +import json +import pickle +import re +import tarfile +import time +from typing import Any +from typing import Callable +from typing import Optional +from typing import Union import numpy as np from . import gopen -def imageencoder(image: Any, format: str = "PNG"): # skipcq: PYL-W0622 +def imageencoder(image: Any, format: str="PNG"): # skipcq: PYL-W0622 """Compress an image using PIL and return it as a string. Can handle float or uint8 images. @@ -67,6 +73,7 @@ def bytestr(data: Any): return data.encode("ascii") return str(data).encode("ascii") + def paddle_dumps(data: Any): """Dump data into a bytestring using paddle.dumps. @@ -82,6 +89,7 @@ def paddle_dumps(data: Any): paddle.save(data, stream) return stream.getvalue() + def numpy_dumps(data: np.ndarray): """Dump data into a bytestring using numpy npy format. @@ -139,9 +147,8 @@ def add_handlers(d, keys, value): def make_handlers(): """Create a list of handlers for encoding data.""" handlers = {} - add_handlers( - handlers, "cls cls2 class count index inx id", lambda x: str(x).encode("ascii") - ) + add_handlers(handlers, "cls cls2 class count index inx id", + lambda x: str(x).encode("ascii")) add_handlers(handlers, "txt text transcript", lambda x: x.encode("utf-8")) add_handlers(handlers, "html htm", lambda x: x.encode("utf-8")) add_handlers(handlers, "pyd pickle", pickle.dumps) @@ -152,7 +159,8 @@ def make_handlers(): add_handlers(handlers, "json jsn", lambda x: json.dumps(x).encode("utf-8")) add_handlers(handlers, "mp msgpack msg", mp_dumps) add_handlers(handlers, "cbor", cbor_dumps) - add_handlers(handlers, "jpg jpeg img image", lambda data: imageencoder(data, "jpg")) + add_handlers(handlers, "jpg jpeg img image", + lambda data: imageencoder(data, "jpg")) add_handlers(handlers, "png", lambda data: imageencoder(data, "png")) add_handlers(handlers, "pbm", lambda data: imageencoder(data, "pbm")) add_handlers(handlers, "pgm", lambda data: imageencoder(data, "pgm")) @@ -192,7 +200,8 @@ def encode_based_on_extension(sample: dict, handlers: dict): :param handlers: handlers for encoding """ return { - k: encode_based_on_extension1(v, k, handlers) for k, v in list(sample.items()) + k: encode_based_on_extension1(v, k, handlers) + for k, v in list(sample.items()) } @@ -258,15 +267,14 @@ class TarWriter: """ def __init__( - self, - fileobj, - user: str = "bigdata", - group: str = "bigdata", - mode: int = 0o0444, - compress: Optional[bool] = None, - encoder: Union[None, bool, Callable] = True, - keep_meta: bool = False, - ): + self, + fileobj, + user: str="bigdata", + group: str="bigdata", + mode: int=0o0444, + compress: Optional[bool]=None, + encoder: Union[None, bool, Callable]=True, + keep_meta: bool=False, ): """Create a tar writer. :param fileobj: stream to write data to @@ -330,8 +338,7 @@ class TarWriter: continue if not isinstance(v, (bytes, bytearray, memoryview)): raise ValueError( - f"{k} doesn't map to a bytes after encoding ({type(v)})" - ) + f"{k} doesn't map to a bytes after encoding ({type(v)})") key = obj["__key__"] for k in sorted(obj.keys()): if k == "__key__": @@ -349,7 +356,8 @@ class TarWriter: ti.uname = self.user ti.gname = self.group if not isinstance(v, (bytes, bytearray, memoryview)): - raise ValueError(f"converter didn't yield bytes: {k}, {type(v)}") + raise ValueError( + f"converter didn't yield bytes: {k}, {type(v)}") stream = io.BytesIO(v) self.tarstream.addfile(ti, stream) total += ti.size @@ -360,14 +368,13 @@ class ShardWriter: """Like TarWriter but splits into multiple shards.""" def __init__( - self, - pattern: str, - maxcount: int = 100000, - maxsize: float = 3e9, - post: Optional[Callable] = None, - start_shard: int = 0, - **kw, - ): + self, + pattern: str, + maxcount: int=100000, + maxsize: float=3e9, + post: Optional[Callable]=None, + start_shard: int=0, + **kw, ): """Create a ShardWriter. :param pattern: output file pattern @@ -400,8 +407,7 @@ class ShardWriter: self.fname, self.count, "%.1f GB" % (self.size / 1e9), - self.total, - ) + self.total, ) self.shard += 1 stream = open(self.fname, "wb") self.tarstream = TarWriter(stream, **self.kw) @@ -413,11 +419,8 @@ class ShardWriter: :param obj: sample to be written """ - if ( - self.tarstream is None - or self.count >= self.maxcount - or self.size >= self.maxsize - ): + if (self.tarstream is None or self.count >= self.maxcount or + self.size >= self.maxsize): self.next_stream() size = self.tarstream.write(obj) self.count += 1 diff --git a/paddlespeech/audio/text/text_featurizer.py b/paddlespeech/audio/text/text_featurizer.py index 91c4d75c..bcd6df54 100644 --- a/paddlespeech/audio/text/text_featurizer.py +++ b/paddlespeech/audio/text/text_featurizer.py @@ -17,6 +17,7 @@ from typing import Union import sentencepiece as spm +from ..utils.log import Logger from .utility import BLANK from .utility import EOS from .utility import load_dict @@ -24,7 +25,6 @@ from .utility import MASKCTC from .utility import SOS from .utility import SPACE from .utility import UNK -from ..utils.log import Logger logger = Logger(__name__) diff --git a/paddlespeech/audio/transform/perturb.py b/paddlespeech/audio/transform/perturb.py index 8044dc36..0825caec 100644 --- a/paddlespeech/audio/transform/perturb.py +++ b/paddlespeech/audio/transform/perturb.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from espnet(https://github.com/espnet/espnet) +import io +import os + +import h5py import librosa import numpy +import numpy as np import scipy import soundfile -import io -import os -import h5py -import numpy as np class SoundHDF5File(): """Collecting sound files to a HDF5 file @@ -109,6 +110,7 @@ class SoundHDF5File(): def close(self): self.file.close() + class SpeedPerturbation(): """SpeedPerturbation @@ -558,4 +560,3 @@ class RIRConvolve(): [scipy.convolve(x, r, mode="same") for r in rir], axis=-1) else: return scipy.convolve(x, rir, mode="same") - diff --git a/paddlespeech/audio/transform/spec_augment.py b/paddlespeech/audio/transform/spec_augment.py index 029e7b8f..b2635066 100644 --- a/paddlespeech/audio/transform/spec_augment.py +++ b/paddlespeech/audio/transform/spec_augment.py @@ -14,6 +14,7 @@ # Modified from espnet(https://github.com/espnet/espnet) """Spec Augment module for preprocessing i.e., data augmentation""" import random + import numpy from PIL import Image diff --git a/paddlespeech/cli/executor.py b/paddlespeech/cli/executor.py index 3800c36d..b53eed88 100644 --- a/paddlespeech/cli/executor.py +++ b/paddlespeech/cli/executor.py @@ -191,7 +191,7 @@ class BaseExecutor(ABC): line = line.strip() if not line: continue - k, v = line.split() # space or \t + k, v = line.split() # space or \t job_contents[k] = v return job_contents diff --git a/paddlespeech/s2t/__init__.py b/paddlespeech/s2t/__init__.py index f6476b9a..5fe2e16b 100644 --- a/paddlespeech/s2t/__init__.py +++ b/paddlespeech/s2t/__init__.py @@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'): paddle.Tensor.new_full = new_full paddle.static.Variable.new_full = new_full + def contiguous(xs: paddle.Tensor) -> paddle.Tensor: return xs diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index cdad3b8f..db60083b 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -25,8 +25,6 @@ import paddle from paddle import distributed as dist from paddlespeech.s2t.frontend.featurizer import TextFeaturizer -from paddlespeech.s2t.io.dataloader import BatchDataLoader -from paddlespeech.s2t.io.dataloader import StreamDataLoader from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.models.u2 import U2Model from paddlespeech.s2t.training.optimizer import OptimizerFactory @@ -109,7 +107,8 @@ class U2Trainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -136,7 +135,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -157,7 +157,8 @@ class U2Trainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -225,14 +226,18 @@ class U2Trainer(Trainer): config = self.config.clone() self.use_streamdata = config.get("use_stream_data", False) if self.train: - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: decode_batch_size = config.get('decode', dict()).get( 'decode_batch_size', 1) - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) - self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) logger.info("Setup test/align Dataloader!") def setup_model(self): diff --git a/paddlespeech/s2t/exps/u2_kaldi/model.py b/paddlespeech/s2t/exps/u2_kaldi/model.py index cb015c11..073d7429 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/model.py +++ b/paddlespeech/s2t/exps/u2_kaldi/model.py @@ -105,7 +105,8 @@ class U2Trainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -133,7 +134,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -153,7 +155,8 @@ class U2Trainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -165,8 +168,8 @@ class U2Trainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) + msg += "batch : {}/{}, ".format( + batch_index + 1, len(self.train_loader)) msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) msg += "data time: {:>.3f}s, ".format(dataload_time) self.train_batch(batch_index, batch, msg) @@ -204,21 +207,24 @@ class U2Trainer(Trainer): self.use_streamdata = config.get("use_stream_data", False) if self.train: config = self.config.clone() - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) config = self.config.clone() config['preprocess_config'] = None - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: config = self.config.clone() config['preprocess_config'] = None - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) config = self.config.clone() config['preprocess_config'] = None - self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args) + self.align_loader = DataLoaderFactory.get_dataloader( + 'align', config, self.args) logger.info("Setup test/align Dataloader!") - def setup_model(self): config = self.config diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py index 60382543..d57c4954 100644 --- a/paddlespeech/s2t/exps/u2_st/model.py +++ b/paddlespeech/s2t/exps/u2_st/model.py @@ -121,7 +121,8 @@ class U2STTrainer(Trainer): def valid(self): self.model.eval() if not self.use_streamdata: - logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}") + logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") valid_losses = defaultdict(list) num_seen_utts = 1 total_loss = 0.0 @@ -155,7 +156,8 @@ class U2STTrainer(Trainer): msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) if not self.use_streamdata: - msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader)) + msg += "batch: {}/{}, ".format(i + 1, + len(self.valid_loader)) msg += ', '.join('{}: {:>.6f}'.format(k, v) for k, v in valid_dump.items()) logger.info(msg) @@ -175,7 +177,8 @@ class U2STTrainer(Trainer): self.before_train() if not self.use_streamdata: - logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") + logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.n_epoch: with Timer("Epoch-Train Time Cost: {}"): self.model.train() @@ -248,14 +251,16 @@ class U2STTrainer(Trainer): config['load_transcript'] = load_transcript self.use_streamdata = config.get("use_stream_data", False) if self.train: - self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args) - self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args) + self.train_loader = DataLoaderFactory.get_dataloader( + 'train', config, self.args) + self.valid_loader = DataLoaderFactory.get_dataloader( + 'valid', config, self.args) logger.info("Setup train/valid Dataloader!") else: - self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args) + self.test_loader = DataLoaderFactory.get_dataloader('test', config, + self.args) logger.info("Setup test Dataloader!") - def setup_model(self): config = self.config model_conf = config diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 735d29da..4cc8274f 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -22,17 +22,16 @@ import paddle from paddle.io import BatchSampler from paddle.io import DataLoader from paddle.io import DistributedBatchSampler +from yacs.config import CfgNode +import paddlespeech.audio.streamdata as streamdata +from paddlespeech.audio.text.text_featurizer import TextFeaturizer from paddlespeech.s2t.io.batchfy import make_batchset from paddlespeech.s2t.io.converter import CustomConverter from paddlespeech.s2t.io.dataset import TransformDataset from paddlespeech.s2t.io.reader import LoadInputsAndTargets from paddlespeech.s2t.utils.log import Log -import paddlespeech.audio.streamdata as streamdata -from paddlespeech.audio.text.text_featurizer import TextFeaturizer -from yacs.config import CfgNode - __all__ = ["BatchDataLoader", "StreamDataLoader"] logger = Log(__name__).getlog() @@ -61,6 +60,7 @@ def batch_collate(x): """ return x[0] + def read_preprocess_cfg(preprocess_conf_file): augment_conf = dict() preprocess_cfg = CfgNode(new_allowed=True) @@ -82,7 +82,8 @@ def read_preprocess_cfg(preprocess_conf_file): augment_conf['num_t_mask'] = process['n_mask'] augment_conf['t_inplace'] = process['inplace'] augment_conf['t_replace_with_zero'] = process['replace_with_zero'] - return augment_conf + return augment_conf + class StreamDataLoader(): def __init__(self, @@ -95,12 +96,12 @@ class StreamDataLoader(): frame_length=25, frame_shift=10, dither=0.0, - minlen_in: float=0.0, + minlen_in: float=0.0, maxlen_in: float=float('inf'), minlen_out: float=0.0, maxlen_out: float=float('inf'), resample_rate: int=16000, - shuffle_size: int=10000, + shuffle_size: int=10000, sort_size: int=1000, n_iter_processes: int=1, prefetch_factor: int=2, @@ -116,11 +117,11 @@ class StreamDataLoader(): text_featurizer = TextFeaturizer(unit_type, vocab_filepath) symbol_table = text_featurizer.vocab_dict - self.feat_dim = num_mel_bins - self.vocab_size = text_featurizer.vocab_size - + self.feat_dim = num_mel_bins + self.vocab_size = text_featurizer.vocab_size + augment_conf = read_preprocess_cfg(preprocess_conf) - + # The list of shard shardlist = [] with open(manifest_file, "r") as f: @@ -128,58 +129,68 @@ class StreamDataLoader(): shardlist.append(line.strip()) world_size = 1 try: - world_size = paddle.distributed.get_world_size() + world_size = paddle.distributed.get_world_size() except Exception as e: logger.warninig(e) - logger.warninig("can not get world_size using paddle.distributed.get_world_size(), use world_size=1") - assert(len(shardlist) >= world_size, "the length of shard list should >= number of gpus/xpus/...") + logger.warninig( + "can not get world_size using paddle.distributed.get_world_size(), use world_size=1" + ) + assert len(shardlist) >= world_size, \ + "the length of shard list should >= number of gpus/xpus/..." - update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0)) + update_n_iter_processes = int( + max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0)) logger.info(f"update_n_iter_processes {update_n_iter_processes}") if update_n_iter_processes != self.n_iter_processes: - self.n_iter_processes = update_n_iter_processes + self.n_iter_processes = update_n_iter_processes logger.info(f"change nun_workers to {self.n_iter_processes}") if self.dist_sampler: base_dataset = streamdata.DataPipeline( - streamdata.SimpleShardList(shardlist), - streamdata.split_by_node if train_mode else streamdata.placeholder(), + streamdata.SimpleShardList(shardlist), streamdata.split_by_node + if train_mode else streamdata.placeholder(), streamdata.split_by_worker, - streamdata.tarfile_to_samples(streamdata.reraise_exception) - ) + streamdata.tarfile_to_samples(streamdata.reraise_exception)) else: base_dataset = streamdata.DataPipeline( streamdata.SimpleShardList(shardlist), streamdata.split_by_worker, - streamdata.tarfile_to_samples(streamdata.reraise_exception) - ) + streamdata.tarfile_to_samples(streamdata.reraise_exception)) self.dataset = base_dataset.append_list( streamdata.audio_tokenize(symbol_table), - streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out), + streamdata.audio_data_filter( + frame_shift=frame_shift, + max_length=maxlen_in, + min_length=minlen_in, + token_max_length=maxlen_out, + token_min_length=minlen_out), streamdata.audio_resample(resample_rate=resample_rate), - streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither), - streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) + streamdata.audio_compute_fbank( + num_mel_bins=num_mel_bins, + frame_length=frame_length, + frame_shift=frame_shift, + dither=dither), + streamdata.audio_spec_aug(**augment_conf) + if train_mode else streamdata.placeholder( + ), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80) streamdata.shuffle(shuffle_size), streamdata.sort(sort_size=sort_size), streamdata.batched(batch_size), streamdata.audio_padding(), - streamdata.audio_cmvn(cmvn_file) - ) + streamdata.audio_cmvn(cmvn_file)) if paddle.__version__ >= '2.3.2': self.loader = streamdata.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - prefetch_factor = self.prefetch_factor, - batch_size=None - ) + self.dataset, + num_workers=self.n_iter_processes, + prefetch_factor=self.prefetch_factor, + batch_size=None) else: self.loader = streamdata.WebLoader( - self.dataset, - num_workers=self.n_iter_processes, - batch_size=None - ) + self.dataset, + num_workers=self.n_iter_processes, + batch_size=None) def __iter__(self): return self.loader.__iter__() @@ -188,7 +199,9 @@ class StreamDataLoader(): return self.__iter__() def __len__(self): - logger.info("Stream dataloader does not support calculate the length of the dataset") + logger.info( + "Stream dataloader does not support calculate the length of the dataset" + ) return -1 @@ -347,7 +360,7 @@ class DataLoaderFactory(): config['train_mode'] = True elif mode == 'valid': config['manifest'] = config.dev_manifest - config['train_mode'] = False + config['train_mode'] = False elif model == 'test' or mode == 'align': config['manifest'] = config.test_manifest config['train_mode'] = False @@ -358,30 +371,31 @@ class DataLoaderFactory(): config['maxlen_out'] = float('inf') config['dist_sampler'] = False else: - raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") - return StreamDataLoader( - manifest_file=config.manifest, - train_mode=config.train_mode, - unit_type=config.unit_type, - preprocess_conf=config.preprocess_config, - batch_size=config.batch_size, - num_mel_bins=config.feat_dim, - frame_length=config.window_ms, - frame_shift=config.stride_ms, - dither=config.dither, - minlen_in=config.minlen_in, - maxlen_in=config.maxlen_in, - minlen_out=config.minlen_out, - maxlen_out=config.maxlen_out, - resample_rate=config.resample_rate, - shuffle_size=config.shuffle_size, - sort_size=config.sort_size, - n_iter_processes=config.num_workers, - prefetch_factor=config.prefetch_factor, - dist_sampler=config.dist_sampler, - cmvn_file=config.cmvn_file, - vocab_filepath=config.vocab_filepath, + raise KeyError( + "not valid mode type!!, please input one of 'train, valid, test, align'" ) + return StreamDataLoader( + manifest_file=config.manifest, + train_mode=config.train_mode, + unit_type=config.unit_type, + preprocess_conf=config.preprocess_config, + batch_size=config.batch_size, + num_mel_bins=config.feat_dim, + frame_length=config.window_ms, + frame_shift=config.stride_ms, + dither=config.dither, + minlen_in=config.minlen_in, + maxlen_in=config.maxlen_in, + minlen_out=config.minlen_out, + maxlen_out=config.maxlen_out, + resample_rate=config.resample_rate, + shuffle_size=config.shuffle_size, + sort_size=config.sort_size, + n_iter_processes=config.num_workers, + prefetch_factor=config.prefetch_factor, + dist_sampler=config.dist_sampler, + cmvn_file=config.cmvn_file, + vocab_filepath=config.vocab_filepath, ) else: if mode == 'train': config['manifest'] = config.train_manifest @@ -411,7 +425,7 @@ class DataLoaderFactory(): config['train_mode'] = False config['sortagrad'] = False config['batch_size'] = config.get('decode', dict()).get( - 'decode_batch_size', 1) + 'decode_batch_size', 1) config['maxlen_in'] = float('inf') config['maxlen_out'] = float('inf') config['minibatches'] = 0 @@ -427,8 +441,10 @@ class DataLoaderFactory(): config['dist_sampler'] = False config['shortest_first'] = False else: - raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'") - + raise KeyError( + "not valid mode type!!, please input one of 'train, valid, test, align'" + ) + return BatchDataLoader( json_file=config.manifest, train_mode=config.train_mode, @@ -450,4 +466,3 @@ class DataLoaderFactory(): num_encs=config.num_encs, dist_sampler=config.dist_sampler, shortest_first=config.shortest_first) - diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py index e86bbedf..e8b61bc0 100644 --- a/paddlespeech/s2t/models/u2_st/u2_st.py +++ b/paddlespeech/s2t/models/u2_st/u2_st.py @@ -18,7 +18,6 @@ Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recogni """ import time from typing import Dict -from typing import List from typing import Optional from typing import Tuple @@ -26,6 +25,8 @@ import paddle from paddle import jit from paddle import nn +from paddlespeech.audio.utils.tensor_utils import add_sos_eos +from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.frontend.utility import IGNORE_ID from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.modules.cmvn import GlobalCMVN @@ -38,8 +39,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.utils import checkpoint from paddlespeech.s2t.utils import layer_tools from paddlespeech.s2t.utils.log import Log -from paddlespeech.audio.utils.tensor_utils import add_sos_eos -from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ["U2STModel", "U2STInferModel"] @@ -401,8 +400,8 @@ class U2STBaseModel(nn.Layer): xs: paddle.Tensor, offset: int, required_cache_size: int, - att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), - cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]), + att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), + cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]), ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: """ Export interface for c++ call, give input chunk xs, and return output from time 0 to current chunk. @@ -435,8 +434,8 @@ class U2STBaseModel(nn.Layer): paddle.Tensor: new conformer cnn cache required for next chunk, with same shape as the original cnn_cache. """ - return self.encoder.forward_chunk( - xs, offset, required_cache_size, att_cache, cnn_cache) + return self.encoder.forward_chunk(xs, offset, required_cache_size, + att_cache, cnn_cache) # @jit.to_static def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: diff --git a/paddlespeech/s2t/modules/align.py b/paddlespeech/s2t/modules/align.py index cacda246..34d79614 100644 --- a/paddlespeech/s2t/modules/align.py +++ b/paddlespeech/s2t/modules/align.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math + import paddle from paddle import nn -import math """ To align the initializer between paddle and torch, the API below are set defalut initializer with priority higger than global initializer. @@ -81,10 +82,18 @@ class Linear(nn.Linear): name=None): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Linear, self).__init__(in_features, out_features, weight_attr, bias_attr, name) @@ -104,10 +113,18 @@ class Conv1D(nn.Conv1D): data_format='NCL'): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Conv1D, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) @@ -128,10 +145,18 @@ class Conv2D(nn.Conv2D): data_format='NCHW'): if weight_attr is None: if global_init_type == "kaiming_uniform": - weight_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) if bias_attr is None: if global_init_type == "kaiming_uniform": - bias_attr = paddle.ParamAttr(initializer=nn.initializer.KaimingUniform(fan_in=None, negative_slope=math.sqrt(5), nonlinearity='leaky_relu')) + bias_attr = paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + fan_in=None, + negative_slope=math.sqrt(5), + nonlinearity='leaky_relu')) super(Conv2D, self).__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, padding_mode, weight_attr, bias_attr, data_format) diff --git a/paddlespeech/s2t/modules/initializer.py b/paddlespeech/s2t/modules/initializer.py index cdcf2e05..6eae5713 100644 --- a/paddlespeech/s2t/modules/initializer.py +++ b/paddlespeech/s2t/modules/initializer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np + class DefaultInitializerContext(object): """ diff --git a/paddlespeech/server/engine/asr/online/ctc_endpoint.py b/paddlespeech/server/engine/asr/online/ctc_endpoint.py index b87dbe80..1b8ad1cb 100644 --- a/paddlespeech/server/engine/asr/online/ctc_endpoint.py +++ b/paddlespeech/server/engine/asr/online/ctc_endpoint.py @@ -102,8 +102,10 @@ class OnlineCTCEndpoint: assert self.num_frames_decoded >= self.trailing_silence_frames assert self.frame_shift_in_ms > 0 - - decoding_something = (self.num_frames_decoded > self.trailing_silence_frames) and decoding_something + + decoding_something = ( + self.num_frames_decoded > self.trailing_silence_frames + ) and decoding_something utterance_length = self.num_frames_decoded * self.frame_shift_in_ms trailing_silence = self.trailing_silence_frames * self.frame_shift_in_ms diff --git a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py index ab4f1130..6daae5be 100644 --- a/paddlespeech/server/engine/asr/online/onnx/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/onnx/asr_engine.py @@ -21,12 +21,12 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder -from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.utils.utility import UpdateConfig from paddlespeech.server.engine.base_engine import BaseEngine from paddlespeech.server.utils import onnx_infer diff --git a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py index 182e6418..0fd5d1bc 100644 --- a/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/paddleinference/asr_engine.py @@ -21,10 +21,10 @@ import paddle from numpy import float32 from yacs.config import CfgNode +from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.log import logger from paddlespeech.resource import CommonTaskResource -from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.utils.utility import UpdateConfig diff --git a/paddlespeech/server/engine/asr/python/asr_engine.py b/paddlespeech/server/engine/asr/python/asr_engine.py index 9ce05d97..e297e5c2 100644 --- a/paddlespeech/server/engine/asr/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/python/asr_engine.py @@ -66,12 +66,14 @@ class ASREngine(BaseEngine): ) logger.error(e) return False - - self.executor._init_from_path( - model_type = self.config.model, lang = self.config.lang, sample_rate = self.config.sample_rate, - cfg_path = self.config.cfg_path, decode_method = self.config.decode_method, - ckpt_path = self.config.ckpt_path) + self.executor._init_from_path( + model_type=self.config.model, + lang=self.config.lang, + sample_rate=self.config.sample_rate, + cfg_path=self.config.cfg_path, + decode_method=self.config.decode_method, + ckpt_path=self.config.ckpt_path) logger.info("Initialize ASR server engine successfully on device: %s." % (self.device)) diff --git a/paddlespeech/t2s/datasets/sampler.py b/paddlespeech/t2s/datasets/sampler.py index a69bc860..3c97d1dc 100644 --- a/paddlespeech/t2s/datasets/sampler.py +++ b/paddlespeech/t2s/datasets/sampler.py @@ -1,8 +1,9 @@ -import paddle import math + import numpy as np from paddle.io import BatchSampler + class ErnieSATSampler(BatchSampler): """Sampler that restricts data loading to a subset of the dataset. In such case, each process can pass a DistributedBatchSampler instance @@ -110,8 +111,8 @@ class ErnieSATSampler(BatchSampler): subsampled_indices.extend(indices[i:i + self.batch_size]) indices = indices[len(indices) - last_batch_size:] - subsampled_indices.extend(indices[ - self.local_rank * last_local_batch_size:( + subsampled_indices.extend( + indices[self.local_rank * last_local_batch_size:( self.local_rank + 1) * last_local_batch_size]) return subsampled_indices diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py index af653ef8..75a666bb 100644 --- a/paddlespeech/t2s/exps/ernie_sat/train.py +++ b/paddlespeech/t2s/exps/ernie_sat/train.py @@ -25,7 +25,6 @@ from paddle import DataParallel from paddle import distributed as dist from paddle import nn from paddle.io import DataLoader -from paddle.io import DistributedBatchSampler from paddle.optimizer import Adam from yacs.config import CfgNode diff --git a/paddlespeech/t2s/exps/ernie_sat/utils.py b/paddlespeech/t2s/exps/ernie_sat/utils.py index 9169efa3..6805e513 100644 --- a/paddlespeech/t2s/exps/ernie_sat/utils.py +++ b/paddlespeech/t2s/exps/ernie_sat/utils.py @@ -11,32 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import os from pathlib import Path from typing import Dict from typing import List from typing import Union -import os import numpy as np import paddle import yaml from yacs.config import CfgNode -import hashlib - from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference + def _get_user(): return os.path.expanduser('~').split('/')[-1] + def str2md5(string): md5_val = hashlib.md5(string.encode('utf8')).hexdigest() return md5_val -def get_tmp_name(text:str): + +def get_tmp_name(text: str): return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text) + def get_dict(dictfile: str): word2phns_dict = {} with open(dictfile, 'r') as fid: diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index c8eb1c64..15d8dfb7 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -298,8 +298,8 @@ def am_to_static(am_inference, am_name = am[:am.rindex('_')] am_dataset = am[am.rindex('_') + 1:] if am_name == 'fastspeech2': - if am_dataset in {"aishell3", "vctk", "mix" - } and speaker_dict is not None: + if am_dataset in {"aishell3", "vctk", + "mix"} and speaker_dict is not None: am_inference = jit.to_static( am_inference, input_spec=[ @@ -311,8 +311,8 @@ def am_to_static(am_inference, am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)]) elif am_name == 'speedyspeech': - if am_dataset in {"aishell3", "vctk", "mix" - } and speaker_dict is not None: + if am_dataset in {"aishell3", "vctk", + "mix"} and speaker_dict is not None: am_inference = jit.to_static( am_inference, input_spec=[ diff --git a/paddlespeech/t2s/frontend/g2pw/__init__.py b/paddlespeech/t2s/frontend/g2pw/__init__.py index 6e1ee0db..0eaeee5d 100644 --- a/paddlespeech/t2s/frontend/g2pw/__init__.py +++ b/paddlespeech/t2s/frontend/g2pw/__init__.py @@ -1,2 +1 @@ from paddlespeech.t2s.frontend.g2pw.onnx_api import G2PWOnnxConverter - diff --git a/paddlespeech/t2s/frontend/mix_frontend.py b/paddlespeech/t2s/frontend/mix_frontend.py index a681445c..101a1e50 100644 --- a/paddlespeech/t2s/frontend/mix_frontend.py +++ b/paddlespeech/t2s/frontend/mix_frontend.py @@ -61,8 +61,11 @@ class MixFrontend(): return False def is_end(self, before_char, after_char) -> bool: - if ((self.is_alphabet(before_char) or before_char == " ") and - (self.is_alphabet(after_char) or after_char == " ")): + flag = 0 + for char in (before_char, after_char): + if self.is_alphabet(char) or char == " ": + flag += 1 + if flag == 2: return True else: return False diff --git a/paddlespeech/t2s/training/updaters/standard_updater.py b/paddlespeech/t2s/training/updaters/standard_updater.py index 668d2fc6..6d3aa709 100644 --- a/paddlespeech/t2s/training/updaters/standard_updater.py +++ b/paddlespeech/t2s/training/updaters/standard_updater.py @@ -24,10 +24,11 @@ from paddle.nn import Layer from paddle.optimizer import Optimizer from timer import timer +from paddlespeech.t2s.datasets.sampler import ErnieSATSampler from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updater import UpdaterBase from paddlespeech.t2s.training.updater import UpdaterState -from paddlespeech.t2s.datasets.sampler import ErnieSATSampler + class StandardUpdater(UpdaterBase): """An example of over-simplification. Things may not be that simple, but diff --git a/setup.py b/setup.py index 079803b7..fac9e120 100644 --- a/setup.py +++ b/setup.py @@ -77,12 +77,7 @@ base = [ "pybind11", ] -server = [ - "fastapi", - "uvicorn", - "pattern_singleton", - "websockets" -] +server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"] requirements = { "install": @@ -330,4 +325,4 @@ setup_info = dict( }) with version_info(): - setup(**setup_info,include_package_data=True) + setup(**setup_info, include_package_data=True) diff --git a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py index 4426d1be..c53e9ec9 100755 --- a/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py +++ b/speechx/examples/ds2_ol/onnx/local/onnx_infer_shape.py @@ -490,18 +490,10 @@ class SymbolicShapeInference: def _onnx_infer_single_node(self, node): # skip onnx shape inference for some ops, as they are handled in _infer_* skip_infer = node.op_type in [ - 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ - # contrib ops - - - - - 'Attention', 'BiasGelu', \ - 'EmbedLayerNormalization', \ - 'FastGelu', 'Gelu', 'LayerNormalization', \ - 'LongformerAttention', \ - 'SkipLayerNormalization', \ - 'PythonOp' + 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', 'Attention', + 'BiasGelu', 'EmbedLayerNormalization', 'FastGelu', 'Gelu', + 'LayerNormalization', 'LongformerAttention', + 'SkipLayerNormalization', 'PythonOp' ] if not skip_infer: @@ -514,8 +506,8 @@ class SymbolicShapeInference: if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: initializers = [ self.initializers_[name] for name in node.input - if (name in self.initializers_ and - name not in self.graph_inputs_) + if (name in self.initializers_ and name not in + self.graph_inputs_) ] # run single node inference with self.known_vi_ shapes @@ -601,8 +593,8 @@ class SymbolicShapeInference: for o in symbolic_shape_inference.out_mp_.graph.output ] subgraph_new_symbolic_dims = set([ - d for s in subgraph_shapes if s for d in s - if type(d) == str and not d in self.symbolic_dims_ + d for s in subgraph_shapes + if s for d in s if type(d) == str and not d in self.symbolic_dims_ ]) new_dims = {} for d in subgraph_new_symbolic_dims: @@ -729,8 +721,9 @@ class SymbolicShapeInference: for d, s in zip(sympy_shape[-rank:], strides) ] total_pads = [ - max(0, (k - s) if r == 0 else (k - r)) for k, s, r in - zip(effective_kernel_shape, strides, residual) + max(0, (k - s) if r == 0 else (k - r)) + for k, s, r in zip(effective_kernel_shape, strides, + residual) ] except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational total_pads = [ @@ -1276,8 +1269,9 @@ class SymbolicShapeInference: if pads is not None: assert len(pads) == 2 * rank new_sympy_shape = [ - d + pad_up + pad_down for d, pad_up, pad_down in - zip(sympy_shape, pads[:rank], pads[rank:]) + d + pad_up + pad_down + for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[ + rank:]) ] self._update_computed_dims(new_sympy_shape) else: @@ -1590,8 +1584,8 @@ class SymbolicShapeInference: scales = list(scales) new_sympy_shape = [ sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in - zip(input_sympy_shape, roi_start, roi_end, scales) + for d, start, end, scale in zip(input_sympy_shape, + roi_start, roi_end, scales) ] self._update_computed_dims(new_sympy_shape) else: @@ -2204,8 +2198,9 @@ class SymbolicShapeInference: # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate sorted_nodes = [] sorted_known_vi = set([ - i.name for i in list(self.out_mp_.graph.input) + - list(self.out_mp_.graph.initializer) + i.name + for i in list(self.out_mp_.graph.input) + list( + self.out_mp_.graph.initializer) ]) if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): # Loop/Scan will have some graph output in graph inputs, so don't do topological sort