pull/1945/head
Hui Zhang 3 years ago
parent 943272385a
commit c15278ed80

@ -26,8 +26,7 @@ def get_audios(path):
"""
supported_formats = [".wav", ".mp3", ".ogg", ".flac", ".m4a"]
return [
item
for sublist in [[os.path.join(dir, file) for file in files]
item for sublist in [[os.path.join(dir, file) for file in files]
for dir, _, files in list(os.walk(path))]
for item in sublist if os.path.splitext(item)[1] in supported_formats
]

@ -13,9 +13,7 @@
# limitations under the License.
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# script for calc RTF: grep -rn RTF log.txt | awk '{print $NF}' | awk -F "=" '{sum += $NF} END {print "all time",sum, "audio num", NR, "RTF", sum/NR}'
import argparse
import asyncio
import codecs

@ -92,5 +92,3 @@ server 的 demo [streaming_asr_server](https://github.com/PaddlePaddle/Paddle
## 4. 快速开始
关于如果使用 PP-ASR可以看这里的 [install](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/docs/source/install_cn.md),其中提供了 **简单**、**中等**、**困难** 三种安装方式。如果想体验 paddlespeech 的推理功能,可以用 **简单** 安装方式。

@ -24,11 +24,11 @@ from typing import Any
from typing import Dict
import paddle
import paddleaudio
import requests
import yaml
from paddle.framework import load
import paddleaudio
from . import download
from .entry import commands
try:

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -153,8 +153,7 @@ class PaddleASRConnectionHanddler:
spectrum = self.collate_fn_test._normalizer.apply(spectrum)
# spectrum augment
feat = self.collate_fn_test.augmentation.transform_feature(
spectrum)
feat = self.collate_fn_test.augmentation.transform_feature(spectrum)
# audio_len is frame num
frame_num = feat.shape[0]
@ -189,7 +188,9 @@ class PaddleASRConnectionHanddler:
assert samples.ndim == 1
self.num_samples += samples.shape[0]
logger.info(f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}")
logger.info(
f"This package receive {samples.shape[0]} pcm data. Global samples:{self.num_samples}"
)
# self.reamined_wav stores all the samples,
# include the original remained_wav and this package samples
@ -234,7 +235,6 @@ class PaddleASRConnectionHanddler:
# update remained wav
self.remained_wav = self.remained_wav[self.n_shift * num_frames:]
logger.info(
f"process the audio feature success, the cached feat shape: {self.cached_feat.shape}"
)
@ -246,7 +246,6 @@ class PaddleASRConnectionHanddler:
else:
raise ValueError(f"not supported: {self.model_type}")
def reset(self):
if "deepspeech2" in self.model_type:
# for deepspeech2
@ -268,7 +267,6 @@ class PaddleASRConnectionHanddler:
self.remained_wav = None
self.cached_feat = None
# partial/ending decoding results
self.result_transcripts = ['']
@ -290,7 +288,6 @@ class PaddleASRConnectionHanddler:
# one best timestamp viterbi prob is large.
self.time_stamp = []
def decode(self, is_finished=False):
"""advance decoding
@ -373,7 +370,6 @@ class PaddleASRConnectionHanddler:
else:
raise Exception("invalid model name")
@paddle.no_grad()
def decode_one_chunk(self, x_chunk, x_chunk_lens):
"""forward one chunk frames
@ -425,10 +421,11 @@ class PaddleASRConnectionHanddler:
logger.info(f"decode one best result for deepspeech2: {trans_best[0]}")
return trans_best[0]
@paddle.no_grad()
def advance_decoding(self, is_finished=False):
logger.info("Conformer/Transformer: start to decode with advanced_decoding method")
logger.info(
"Conformer/Transformer: start to decode with advanced_decoding method"
)
cfg = self.ctc_decode_config
# cur chunk size, in decoding frame unit
@ -563,7 +560,6 @@ class PaddleASRConnectionHanddler:
"""
return self.word_time_stamp
@paddle.no_grad()
def rescoring(self):
"""Second-Pass Decoding,
@ -574,7 +570,9 @@ class PaddleASRConnectionHanddler:
return
if "attention_rescoring" != self.ctc_decode_config.decoding_method:
logger.info(f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring")
logger.info(
f"decoding method not match: {self.ctc_decode_config.decoding_method}, need attention_rescoring"
)
return
logger.info("rescoring the final result")
@ -605,7 +603,8 @@ class PaddleASRConnectionHanddler:
hyp_content, place=self.device, dtype=paddle.long)
hyp_list.append(hyp_content)
hyps_pad = pad_sequence(hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_pad = pad_sequence(
hyp_list, batch_first=True, padding_value=self.model.ignore_id)
hyps_lens = paddle.to_tensor(
[len(hyp[0]) for hyp in hyps], place=self.device,
dtype=paddle.long) # (beam_size,)
@ -694,7 +693,6 @@ class PaddleASRConnectionHanddler:
logger.info(f"word time stamp: {self.word_time_stamp}")
class ASRServerExecutor(ASRExecutor):
def __init__(self):
super().__init__()

@ -17,12 +17,12 @@ from typing import List
from fastapi import APIRouter
from paddlespeech.cli.log import logger
from paddlespeech.server.restful.acs_api import router as acs_router
from paddlespeech.server.restful.asr_api import router as asr_router
from paddlespeech.server.restful.cls_api import router as cls_router
from paddlespeech.server.restful.text_api import router as text_router
from paddlespeech.server.restful.tts_api import router as tts_router
from paddlespeech.server.restful.vector_api import router as vec_router
from paddlespeech.server.restful.acs_api import router as acs_router
_router = APIRouter()

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class Frame(object):
"""Represents a "frame" of audio data."""
@ -77,8 +78,8 @@ class ChunkBuffer(object):
offset = 0
while offset + self.window_bytes <= len(audio):
yield Frame(audio[offset:offset + self.window_bytes], self.timestamp,
self.window_sec)
yield Frame(audio[offset:offset + self.window_bytes],
self.timestamp, self.window_sec)
self.timestamp += self.shift_sec
offset += self.shift_bytes

@ -176,7 +176,10 @@ def main():
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
"--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
args, _ = parser.parse_known_args()

@ -188,7 +188,10 @@ def main():
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
"--nxpu",
type=int,
default=0,
help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")

@ -98,7 +98,6 @@ requirements = {
}
def check_call(cmd: str, shell=False, executable=None):
try:
sp.check_call(
@ -112,6 +111,7 @@ def check_call(cmd: str, shell=False, executable=None):
file=sys.stderr)
raise e
def check_output(cmd: str, shell=False):
try:
out_bytes = sp.check_output(cmd.split())
@ -146,6 +146,7 @@ def _remove(files: str):
for f in files:
f.unlink()
################################# Install ##################################
@ -308,6 +309,5 @@ setup_info = dict(
]
})
with version_info():
setup(**setup_info)

@ -20,7 +20,6 @@ of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from pathlib import Path
@ -106,6 +105,7 @@ def create_manifest(data_dir, manifest_path_prefix):
manifest_dir = os.path.dirname(manifest_path_prefix)
def prepare_dataset(url, md5sum, target_dir, manifest_path=None):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')

Loading…
Cancel
Save