pull/1945/head
Hui Zhang 2 years ago
parent 943272385a
commit c15278ed80

@@ -26,9 +26,8 @@ def get_audios(path):
    """
    supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
    return [
-        item
-        for sublist in [[os.path.join(dir, file) for file in files]
-                        for dir, _, files in list(os.walk(path))]
+        item for sublist in [[os.path.join(dir, file) for file in files]
+                             for dir, _, files in list(os.walk(path))]
        for item in sublist if os.path.splitext(item)[1] in supported_formats
    ]
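As an aside for readers of this hunk, a minimal standalone sketch of what `get_audios` computes, written without the nested-sublist detour (the sample call is illustrative):

```python
import os

SUPPORTED_FORMATS = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]


def get_audios(path):
    # Walk the directory tree and keep files whose extension is supported.
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(path)
        for name in names
        if os.path.splitext(name)[1] in SUPPORTED_FORMATS
    ]


# e.g. get_audios("./wavs") -> ["./wavs/a.wav", "./wavs/sub/b.flac", ...]
```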

@@ -62,4 +62,4 @@ I0513 10:58:13.884493 41768 feature_cache.h:52] set finished
I0513 10:58:24.247171 41768 paddle_nnet.h:76] Tensor neml: 10240
I0513 10:58:24.247249 41768 paddle_nnet.h:76] Tensor neml: 10240
LOG ([5.5.544~2-f21d7]:main():decoder/recognizer_test_main.cc:90) the result of case_10 is 五月十二日二十二点三十六分加班打车回家四十一元
```

@@ -13,9 +13,7 @@
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import argparse
import asyncio
import codecs
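A note on the RTF one-liner in the script header above: it averages the `RTF=<seconds>` value over all matching log lines. A minimal Python sketch of the same computation, assuming that log format:

```python
import re


def rtf_from_log(path="log.txt"):
    # Collect every RTF=<value> occurrence, then report the average,
    # mirroring the grep | awk pipeline in the script header.
    values = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            m = re.search(r"RTF=([0-9.]+)", line)
            if m:
                values.append(float(m.group(1)))
    if values:
        print("all time", sum(values), "audio num", len(values), "RTF",
              sum(values) / len(values))
```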

@@ -92,5 +92,3 @@ the server demo [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. Quick Start
For how to use PP-ASR, see [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md), which offers three installation modes: **Easy**, **Medium**, and **Hard**. To try out paddlespeech inference, the **Easy** installation is enough.
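After the Easy install, a hedged quick-start sketch using the PaddleSpeech Python CLI executor (the wav path is a placeholder; models download on first call):

```python
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
# Runs the default Mandarin ASR model on a local 16 kHz wav file.
result = asr(audio_file="./zh.wav")
print(result)
```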

@@ -24,11 +24,11 @@ from typing import Any
from typing import Dict

import paddle
-import paddleaudio
import requests
import yaml
from paddle.framework import load
+import paddleaudio

from . import download
from .entry import commands

try:

@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
    """
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batch_indices)
    batch_indices = [item for batch in batch_indices for item in batch]
    assert clipped is False
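The line being reformatted relies on the `zip(*[iter(seq)] * n)` grouping idiom; a small self-contained demonstration:

```python
indices = list(range(10))
batch_size = 3
shift_len = 1

# One iterator repeated batch_size times: zip pulls batch_size consecutive
# items per tuple, chunking the shifted list and dropping the remainder.
batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
print(batches)  # [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
```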

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -153,8 +153,7 @@ class PaddleASRConnectionHanddler:
        spectrum = self.collate_fn_test._normalizer.apply(spectrum)

        # spectrum augment
-        feat = self.collate_fn_test.augmentation.transform_feature(
-            spectrum)
+        feat = self.collate_fn_test.augmentation.transform_feature(spectrum)

        # audio_len is frame num
        frame_num = feat.shape[0]
@@ -189,14 +188,16 @@ class PaddleASRConnectionHanddler:
        assert samples.ndim == 1
        self.num_samples += samples.shape[0]
-        logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}")
+        logger.info(
+            f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
+        )

        # self.reamined_wav stores all the samples,
        # include the original remained_wav and this package samples
        if self.remained_wav is None:
            self.remained_wav = samples
        else:
            assert self.remained_wav.ndim == 1  # (T,)
            self.remained_wav = np.concatenate([self.remained_wav, samples])

        logger.info(
            f"The concatenation of remain and now audio samples length is: {self.remained_wav.shape}"
@@ -216,8 +217,8 @@ class PaddleASRConnectionHanddler:
        if self.cached_feat is None:
            self.cached_feat = x_chunk
        else:
            assert (len(x_chunk.shape) == 3)  # (B,T,D)
            assert (len(self.cached_feat.shape) == 3)  # (B,T,D)
            self.cached_feat = paddle.concat(
                [self.cached_feat, x_chunk], axis=1)
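A toy sketch of the caching pattern these hunks touch: raw pcm accumulates in a 1-D buffer while extracted features are concatenated along the time axis (numpy stands in for paddle here):

```python
import numpy as np

remained_wav = None  # pcm buffer, shape (T,)
cached_feat = None   # feature cache, shape (B, T, D)


def feed(samples, feats):
    global remained_wav, cached_feat
    # Append new pcm samples to whatever was left from the last package.
    remained_wav = (samples if remained_wav is None else
                    np.concatenate([remained_wav, samples]))
    # Append new feature frames along the time axis (axis=1).
    cached_feat = (feats if cached_feat is None else
                   np.concatenate([cached_feat, feats], axis=1))


feed(np.zeros(160, dtype=np.int16), np.zeros((1, 10, 80)))
feed(np.zeros(320, dtype=np.int16), np.zeros((1, 20, 80)))
print(remained_wav.shape, cached_feat.shape)  # (480,) (1, 30, 80)
```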
@@ -234,18 +235,16 @@ class PaddleASRConnectionHanddler:
            # update remained wav
            self.remained_wav = self.remained_wav[self.n_shift * num_frames:]

            logger.info(
                f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
            )
            logger.info(
                f"After extract feat, the cached remain the audio samples: {self.remained_wav.shape}"
            )
            logger.info(f"global samples: {self.num_samples}")
            logger.info(f"global frames: {self.num_frames}")
        else:
            raise ValueError(f"not supported: {self.model_type}")

    def reset(self):
        if "deepspeech2" in self.model_type:
@@ -263,12 +262,11 @@ class PaddleASRConnectionHanddler:
        # global sample and frame step
        self.num_samples = 0
        self.num_frames = 0

        # cache for audio and feat
        self.remained_wav = None
        self.cached_feat = None

        # partial/ending decoding results
        self.result_transcripts = ['']
@@ -280,17 +278,16 @@ class PaddleASRConnectionHanddler:
        self.conformer_cnn_cache = None
        self.encoder_out = None

        # conformer decoding state
        self.chunk_num = 0  # globa decoding chunk num
        self.offset = 0  # global offset in decoding frame unit
        self.hyps = []

        # token timestamp result
        self.word_time_stamp = []

        # one best timestamp viterbi prob is large.
        self.time_stamp = []

    def decode(self, is_finished=False):
        """advance decoding
@@ -307,7 +304,7 @@ class PaddleASRConnectionHanddler:
            decoding_chunk_size = 1  # decoding chunk size = 1. int decoding frame unit
            context = 7  # context=7, in audio frame unit
            subsampling = 4  # subsampling=4, in audio frame unit
            cached_feature_num = context - subsampling
            # decoding window for model, in audio frame unit
            decoding_window = (decoding_chunk_size - 1) * subsampling + context
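A quick worked example of the window arithmetic in this hunk, under the constants shown (chunk size 1, context 7, subsampling 4):

```python
decoding_chunk_size = 1  # decoding steps per chunk
context = 7              # encoder receptive field, in audio frames
subsampling = 4          # encoder frame-rate reduction

cached_feature_num = context - subsampling  # frames kept between chunks
decoding_window = (decoding_chunk_size - 1) * subsampling + context
print(cached_feature_num, decoding_window)  # 3 7
```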
@@ -373,7 +370,6 @@ class PaddleASRConnectionHanddler:
        else:
            raise Exception("invalid model name")

    @paddle.no_grad()
    def decode_one_chunk(self, x_chunk, x_chunk_lens):
        """forward one chunk frames
@@ -425,10 +421,11 @@ class PaddleASRConnectionHanddler:
        logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
        return trans_best[0]

    @paddle.no_grad()
    def advance_decoding(self, is_finished=False):
-        logger.info("Conformer/Transformer: start to decode with advanced_decoding method")
+        logger.info(
+            "Conformer/Transformer: start to decode with advanced_decoding method"
+        )
        cfg = self.ctc_decode_config

        # cur chunk size, in decoding frame unit
@@ -563,7 +560,6 @@ class PaddleASRConnectionHanddler:
        """
        return self.word_time_stamp

    @paddle.no_grad()
    def rescoring(self):
        """Second-Pass Decoding,
@@ -572,9 +568,11 @@ class PaddleASRConnectionHanddler:
        if "deepspeech2" in self.model_type:
            logger.info("deepspeech2 not support rescoring decoding.")
            return

        if "attention_rescoring" != self.ctc_decode_config.decoding_method:
-            logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
+            logger.info(
+                f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
+            )
            return

        logger.info("rescoring the final result")
@@ -605,7 +603,8 @@ class PaddleASRConnectionHanddler:
                hyp_content, place=self.device, dtype=paddle.long)
            hyp_list.append(hyp_content)
-        hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id)
+        hyps_pad = pad_sequence(
+            hyp_list, batch_first=True, padding_value=self.model.ignore_id)
        hyps_lens = paddle.to_tensor(
            [len(hyp[0]) for hyp in hyps], place=self.device,
            dtype=paddle.long)  # (beam_size,)
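For context, `pad_sequence` right-pads the variable-length hypothesis token lists into one rectangular batch; a minimal numpy stand-in (the real helper returns a paddle Tensor):

```python
import numpy as np


def pad_sequence_np(seqs, padding_value=0):
    # Right-pad each 1-D sequence to the longest length, batch-first layout.
    max_len = max(len(s) for s in seqs)
    out = np.full((len(seqs), max_len), padding_value, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out


print(pad_sequence_np([[5, 9, 2], [7]], padding_value=-1))
# [[ 5  9  2]
#  [ 7 -1 -1]]
```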
@@ -689,12 +688,11 @@ class PaddleASRConnectionHanddler:
                "ed": end
            })
            # logger.info(f"{word_time_stamp[-1]}")

        self.word_time_stamp = word_time_stamp
        logger.info(f"word time stamp: {self.word_time_stamp}")


class ASRServerExecutor(ASRExecutor):
    def __init__(self):
        super().__init__()
@@ -741,7 +739,7 @@ class ASRServerExecutor(ASRExecutor):
        self.am_model = os.path.abspath(am_model)
        self.am_params = os.path.abspath(am_params)
        self.res_path = os.path.dirname(
            os.path.dirname(os.path.abspath(self.cfg_path)))

        logger.info(self.cfg_path)
        logger.info(self.am_model)
@@ -855,7 +853,7 @@ class ASRServerExecutor(ASRExecutor):
            self.transformer_decode_reset()
        else:
            raise ValueError(f"Not support: {model_type}")

        return True

@@ -17,12 +17,12 @@ from typing import List

from fastapi import APIRouter

from paddlespeech.cli.log import logger
-from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
+from paddlespeech.server.restful.acs_api import router as acs_router

_router = APIRouter()

@@ -248,7 +248,7 @@ class ASRHttpHandler:
        }

        res = requests.post(url=self.url, data=json.dumps(data))

        return res.json()
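A hedged sketch of the same request from a bare client: POST a JSON body to the server route and parse the JSON reply (the URL and payload fields are placeholders, not the server's documented schema):

```python
import base64
import json

import requests

url = "http://127.0.0.1:8090/paddlespeech/asr"  # placeholder endpoint
with open("zh.wav", "rb") as f:
    audio = base64.b64encode(f.read()).decode("utf8")

data = {"audio": audio, "audio_format": "wav", "sample_rate": 16000}
res = requests.post(url=url, data=json.dumps(data))
print(res.json())
```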

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


class Frame(object):
    """Represents a "frame" of audio data."""
@@ -45,7 +46,7 @@ class ChunkBuffer(object):
        self.shift_ms = shift_ms
        self.sample_rate = sample_rate
        self.sample_width = sample_width  # int16 = 2; float32 = 4
        self.window_sec = float((self.window_n - 1) * self.shift_ms +
                                self.window_ms) / 1000.0
        self.shift_sec = float(self.shift_n * self.shift_ms / 1000.0)
@@ -77,8 +78,8 @@ class ChunkBuffer(object):
        offset = 0
        while offset + self.window_bytes <= len(audio):
-            yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
-                        self.window_sec)
+            yield Frame(audio[offset:offset + self.window_bytes],
+                        self.timestamp, self.window_sec)
            self.timestamp += self.shift_sec
            offset += self.shift_bytes
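A worked example of the windowing arithmetic above, assuming window_ms=25, shift_ms=10, window_n=7, shift_n=1 (illustrative values; the real defaults are not shown in this hunk):

```python
window_ms, shift_ms = 25, 10  # per-frame window and hop, in ms
window_n, shift_n = 7, 1      # frames per chunk, frames per hop

window_sec = float((window_n - 1) * shift_ms + window_ms) / 1000.0
shift_sec = float(shift_n * shift_ms / 1000.0)
print(window_sec, shift_sec)  # 0.085 0.01
```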

@@ -176,7 +176,10 @@ def main():
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
    parser.add_argument(
-        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
+        "--nxpu",
+        type=int,
+        default=0,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")

    args, _ = parser.parse_known_args()
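A minimal sketch of the device-selection rule these help strings describe (the helper name is illustrative):

```python
def pick_device(ngpu: int, nxpu: int) -> str:
    # ngpu > 0 selects gpu; otherwise nxpu > 0 selects xpu; otherwise cpu.
    if ngpu > 0:
        return "gpu"
    if nxpu > 0:
        return "xpu"
    return "cpu"


print(pick_device(ngpu=0, nxpu=1))  # xpu
print(pick_device(ngpu=0, nxpu=0))  # cpu
```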

@@ -188,7 +188,10 @@ def main():
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
-        "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
+        "--nxpu",
+        type=int,
+        default=0,
+        help="if nxpu == 0 and ngpu == 0, use cpu.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")

@@ -36,4 +36,4 @@ def repeat(N, fn):
    Returns:
        MultiSequential: Repeated model instance.
    """
-    return MultiSequential(*[fn(n) for n in range(N)])
+    return MultiSequential(* [fn(n) for n in range(N)])
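A toy illustration of the `repeat` pattern: build N layers from a per-index factory (a plain Python list stands in for the paddle `MultiSequential` container):

```python
def repeat(N, fn):
    # Call the factory once per layer index and collect the results.
    return [fn(n) for n in range(N)]


layers = repeat(3, lambda n: f"encoder_layer_{n}")
print(layers)  # ['encoder_layer_0', 'encoder_layer_1', 'encoder_layer_2']
```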

@@ -98,7 +98,6 @@ requirements = {
}


def check_call(cmd: str, shell=False, executable=None):
    try:
        sp.check_call(
@@ -112,12 +111,13 @@ def check_call(cmd: str, shell=False, executable=None):
            file=sys.stderr)
        raise e


def check_output(cmd: str, shell=False):
    try:
        out_bytes = sp.check_output(cmd.split())
    except sp.CalledProcessError as e:
        out_bytes = e.output  # Output generated before error
        code = e.returncode  # Return code
        print(
            f"{__file__}:{inspect.currentframe().f_lineno}: CMD: {cmd}, Error:",
            out_bytes,
@@ -146,6 +146,7 @@ def _remove(files: str):
    for f in files:
        f.unlink()


################################# Install ##################################
@@ -308,6 +309,5 @@ setup_info = dict(
    ]
})


with version_info():
    setup(**setup_info)

@@ -20,7 +20,6 @@ of each audio file in the data set.
"""
import argparse
import codecs
-import json
import os
from pathlib import Path
@@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
            duration = float(len(audio_data) / samplerate)
            text = transcript_dict[audio_id]
            json_lines.append(audio_path)
-            reference_lines.append(str(total_num+1) + "\t" + text)
+            reference_lines.append(str(total_num + 1) + "\t" + text)

            total_sec += duration
            total_text += len(text)
@@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
    manifest_dir = os.path.dirname(manifest_path_prefix)


def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
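As a footnote to the duration line in this file, audio duration is simply sample count divided by sample rate; a toy check (numpy stands in for the decoded audio):

```python
import numpy as np

samplerate = 16000
audio_data = np.zeros(40000)  # 40000 samples at 16 kHz

duration = float(len(audio_data) / samplerate)
print(duration)  # 2.5 seconds
```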
