pull/1945/head
Hui Zhang 3 years ago
parent 943272385a
commit c15278ed80

@ -26,8 +26,7 @@ def get_audios(path):
""" """
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"] supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [ return [
item item for sublist in [[os.path.join(dir, file) for file in files]
for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))] for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats for item in sublist if os.path.splitext(item)[1] in supported_formats
] ]

@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
#!/usr/bin/python #!/usr/bin/python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}' # script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import argparse import argparse
import asyncio import asyncio
import codecs import codecs

@ -92,5 +92,3 @@ server 的 demo [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. 快速开始 ## 4. 快速开始
关于如何使用 PP-ASR,可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单**、**中等**、**困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。 关于如何使用 PP-ASR,可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单**、**中等**、**困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict from typing import Dict
import paddle import paddle
import paddleaudio
import requests import requests
import yaml import yaml
from paddle.framework import load from paddle.framework import load
import paddleaudio
from . import download from . import download
from .entry import commands from .entry import commands
try: try:

@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
""" """
rng = np.random.RandomState(epoch) rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1) shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False assert clipped is False

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -153,8 +153,7 @@ class PaddleASRConnectionHanddler:
spectrum = self.collate_fn_test._normalizer.apply(spectrum) spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment # spectrum augment
feat = self.collate_fn_test.augmentation.transform_feature( feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
spectrum)
# audio_len is frame num # audio_len is frame num
frame_num = feat.shape[0] frame_num = feat.shape[0]
@ -189,7 +188,9 @@ class PaddleASRConnectionHanddler:
assert samples.ndim == 1 assert samples.ndim == 1
self.num_samples += samples.shape[0] self.num_samples += samples.shape[0]
logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}") logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.remained_wav stores all the samples, # self.remained_wav stores all the samples,
# include the original remained_wav and this package samples # include the original remained_wav and this package samples
@ -234,7 +235,6 @@ class PaddleASRConnectionHanddler:
# update remained wav # update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:] self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info( logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}" f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
) )
@ -246,7 +246,6 @@ class PaddleASRConnectionHanddler:
else: else:
raise ValueError(f"not supported: {self.model_type}") raise ValueError(f"not supported: {self.model_type}")
def reset(self): def reset(self):
if "deepspeech2" in self.model_type: if "deepspeech2" in self.model_type:
# for deepspeech2 # for deepspeech2
@ -268,7 +267,6 @@ class PaddleASRConnectionHanddler:
self.remained_wav = None self.remained_wav = None
self.cached_feat = None self.cached_feat = None
# partial/ending decoding results # partial/ending decoding results
self.result_transcripts = [''] self.result_transcripts = ['']
@ -290,7 +288,6 @@ class PaddleASRConnectionHanddler:
# one best timestamp viterbi prob is large. # one best timestamp viterbi prob is large.
self.time_stamp = [] self.time_stamp = []
def decode(self, is_finished=False): def decode(self, is_finished=False):
"""advance decoding """advance decoding
@ -373,7 +370,6 @@ class PaddleASRConnectionHanddler:
else: else:
raise Exception("invalid model name") raise Exception("invalid model name")
@paddle.no_grad() @paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens): def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames """forward one chunk frames
@ -425,10 +421,11 @@ class PaddleASRConnectionHanddler:
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}") logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0] return trans_best[0]
@paddle.no_grad() @paddle.no_grad()
def advance_decoding(self, is_finished=False): def advance_decoding(self, is_finished=False):
logger.info("Conformer/Transformer: start to decode with advanced_decoding method") logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit # cur chunk size, in decoding frame unit
@ -563,7 +560,6 @@ class PaddleASRConnectionHanddler:
""" """
return self.word_time_stamp return self.word_time_stamp
@paddle.no_grad() @paddle.no_grad()
def rescoring(self): def rescoring(self):
"""Second-Pass Decoding, """Second-Pass Decoding,
@ -574,7 +570,9 @@ class PaddleASRConnectionHanddler:
return return
if "attention_rescoring" != self.ctc_decode_config.decoding_method: if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring") logger.info(
f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
)
return return
logger.info("rescoring the final result") logger.info("rescoring the final result")
@ -605,7 +603,8 @@ class PaddleASRConnectionHanddler:
hyp_content, place=self.device, dtype=paddle.long) hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content) hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id) hyps_pad = pad_sequence(
hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_lens = paddle.to_tensor( hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device, [len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)
@ -694,7 +693,6 @@ class PaddleASRConnectionHanddler:
logger.info(f"word time stamp: {self.word_time_stamp}") logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor): class ASRServerExecutor(ASRExecutor):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

@ -17,12 +17,12 @@ from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router from paddlespeech.server.restful.vector_api import router as vec_router
from paddlespeech.server.restful.acs_api import router as acs_router
_router = APIRouter() _router = APIRouter()

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class Frame(object): class Frame(object):
"""Represents a "frame" of audio data.""" """Represents a "frame" of audio data."""
@ -77,8 +78,8 @@ class ChunkBuffer(object):
offset = 0 offset = 0
while offset + self.window_bytes <= len(audio): while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], self.timestamp, yield Frame(audio[offset:offset + self.window_bytes],
self.window_sec) self.timestamp, self.window_sec)
self.timestamp += self.shift_sec self.timestamp += self.shift_sec
offset += self.shift_bytes offset += self.shift_bytes

@ -176,7 +176,10 @@ def main():
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
parser.add_argument( parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") "--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()

@ -188,7 +188,10 @@ def main():
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument( parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.") "--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")

@ -36,4 +36,4 @@ def repeat(N, fn):
Returns: Returns:
MultiSequential: Repeated model instance. MultiSequential: Repeated model instance.
""" """
return MultiSequential(*[fn(n) for n in range(N)]) return MultiSequential(* [fn(n) for n in range(N)])

@ -98,7 +98,6 @@ requirements = {
} }
def check_call(cmd: str, shell=False, executable=None): def check_call(cmd: str, shell=False, executable=None):
try: try:
sp.check_call( sp.check_call(
@ -112,6 +111,7 @@ def check_call(cmd: str, shell=False, executable=None):
file=sys.stderr) file=sys.stderr)
raise e raise e
def check_output(cmd: str, shell=False): def check_output(cmd: str, shell=False):
try: try:
out_bytes = sp.check_output(cmd.split()) out_bytes = sp.check_output(cmd.split())
@ -146,6 +146,7 @@ def _remove(files: str):
for f in files: for f in files:
f.unlink() f.unlink()
################################# Install ################################## ################################# Install ##################################
@ -308,6 +309,5 @@ setup_info = dict(
] ]
}) })
with version_info(): with version_info():
setup(**setup_info) setup(**setup_info)

@ -20,7 +20,6 @@ of each audio file in the data set.
""" """
import argparse import argparse
import codecs import codecs
import json
import os import os
from pathlib import Path from pathlib import Path
@ -89,7 +88,7 @@ def create_manifest(data_dir, manifest_path_prefix):
duration = float(len(audio_data) / samplerate) duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id] text = transcript_dict[audio_id]
json_lines.append(audio_path) json_lines.append(audio_path)
reference_lines.append(str(total_num+1) + "\t" + text) reference_lines.append(str(total_num + 1) + "\t" + text)
total_sec += duration total_sec += duration
total_text += len(text) total_text += len(text)
@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
manifest_dir = os.path.dirname(manifest_path_prefix) manifest_dir = os.path.dirname(manifest_path_prefix)
def prepare_dataset(url, md5sum, target_dir, manifest_path=None): def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file.""" """Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell') data_dir = os.path.join(target_dir, 'data_aishell')

Loading…
Cancel
Save