Merge branch 'PaddlePaddle:develop' into develop

pull/2647/head
liangym 3 years ago committed by GitHub
commit 664aed45e1

@@ -1,7 +1,8 @@
 #!/bin/bash

+# ./local/quant.sh conf/chunk_conformer_u2pp.yaml conf/tuning/chunk_decode.yaml exp/chunk_conformer_u2pp/checkpoints/avg_10 data/wav.aishell.test.scp
 if [ $# != 4 ];then
-    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
+    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_scp"
     exit -1
 fi
@@ -11,16 +12,15 @@ echo "using $ngpu gpus..."
 config_path=$1
 decode_config_path=$2
 ckpt_prefix=$3
-audio_file=$4
+audio_scp=$4

 mkdir -p data
 wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
 if [ $? -ne 0 ]; then
     exit 1
 fi

-if [ ! -f ${audio_file} ]; then
-    echo "Plase input the right audio_file path"
+if [ ! -f ${audio_scp} ]; then
+    echo "Please input the right audio_scp path"
     exit 1
 fi
@@ -49,7 +49,8 @@ for type in attention_rescoring; do
         --checkpoint_path ${ckpt_prefix} \
         --opts decode.decoding_method ${type} \
         --opts decode.decode_batch_size ${batch_size} \
-        --audio_file ${audio_file}
+        --num_utts 200 \
+        --audio_scp ${audio_scp}

     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"

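The quantization script now takes a Kaldi-style `wav.scp` list (one "utterance id, wav path" pair per line) instead of a single wav file. A minimal sketch of building such a list, assuming a hypothetical directory of 16 kHz wavs (the directory and output path are illustrative, not from this PR):

```python
# Build a Kaldi-style wav.scp: one "<utt_id> <wav_path>" pair per line.
# `data/wavs` and the output path below are assumptions for illustration.
from pathlib import Path

wav_dir = Path("data/wavs")  # hypothetical folder of 16 kHz test wavs
with open("data/wav.aishell.test.scp", "w") as scp:
    for wav in sorted(wav_dir.glob("*.wav")):
        scp.write(f"{wav.stem} {wav.resolve()}\n")
```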
@@ -54,3 +54,8 @@ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
     # test a single .wav file
     CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
 fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+    # export quant model, please see local/quant.sh
+    :  # no-op placeholder: a comment-only body is a bash syntax error
+fi

@@ -11,13 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Evaluation for U2 model."""
-import os
-import sys
-from pathlib import Path
+"""Quantization U2 model."""
 import paddle
-import soundfile
+from kaldiio import ReadHelper
 from paddleslim import PTQ
 from yacs.config import CfgNode
@@ -34,7 +30,7 @@ class U2Infer():
     def __init__(self, config, args):
         self.args = args
         self.config = config
-        self.audio_file = args.audio_file
+        self.audio_scp = args.audio_scp
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
@@ -63,13 +59,15 @@ class U2Infer():
         self.model.set_state_dict(model_dict)

     def run(self):
-        check(args.audio_file)
+        cnt = 0
+        with ReadHelper(f"scp:{self.audio_scp}") as reader:
+            for key, (rate, audio) in reader:
+                assert rate == 16000
+                cnt += 1
+                if cnt > args.num_utts:
+                    break

-        with paddle.no_grad():
-            # read
-            audio, sample_rate = soundfile.read(
-                self.audio_file, dtype="int16", always_2d=True)
-            audio = audio[:, 0]
-            logger.info(f"audio shape: {audio.shape}")
+                with paddle.no_grad():
+                    logger.info(f"audio shape: {audio.shape}")

                     # fbank
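For context on the loop above: given a `wav.scp`, kaldiio's `ReadHelper` yields `(key, (sample_rate, samples))` pairs, where `samples` is an int16 numpy array. A standalone sketch (the scp path is an assumption):

```python
# Iterate a wav.scp with kaldiio, mirroring the loop added in the diff above.
from kaldiio import ReadHelper

with ReadHelper("scp:data/wav.aishell.test.scp") as reader:
    for key, (rate, audio) in reader:
        assert rate == 16000  # the U2 model expects 16 kHz input
        print(key, audio.shape, audio.dtype)
```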
@@ -80,7 +78,6 @@ class U2Infer():
                     xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
                     decode_config = self.config.decode
                     logger.info(f"decode cfg: {decode_config}")
-                    reverse_weight = getattr(decode_config, 'reverse_weight', 0.0)
                     result_transcripts = self.model.decode(
                         xs,
                         ilen,
@@ -89,11 +86,12 @@ class U2Infer():
                         beam_size=decode_config.beam_size,
                         ctc_weight=decode_config.ctc_weight,
                         decoding_chunk_size=decode_config.decoding_chunk_size,
-                        num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
+                        num_decoding_left_chunks=decode_config.
+                        num_decoding_left_chunks,
                         simulate_streaming=decode_config.simulate_streaming,
-                        reverse_weight=reverse_weight)
+                        reverse_weight=decode_config.reverse_weight)
                     rsl = result_transcripts[0][0]
-                    utt = Path(self.audio_file).name
+                    utt = key
                     logger.info(f"hyp: {utt} {rsl}")
                     # print(self.model)
                     # print(self.model.forward_encoder_chunk)
@@ -161,35 +159,12 @@ class U2Infer():
             # jit save
             logger.info(f"export save: {self.args.export_path}")
-            config = {
-                'is_static': True,
-                'combine_params': True,
-                'skip_forward': True
-            }
-            self.ptq.save_quantized_model(self.model, self.args.export_path)
-            # paddle.jit.save(
-            #     self.model,
-            #     self.args.export_path,
-            #     combine_params=True,
-            #     skip_forward=True)
-
-
-def check(audio_file):
-    if not os.path.isfile(audio_file):
-        print("Please input the right audio file path")
-        sys.exit(-1)
-    logger.info("checking the audio file format......")
-    try:
-        sig, sample_rate = soundfile.read(audio_file)
-    except Exception as e:
-        logger.error(str(e))
-        logger.error(
-            "can not open the wav file, please check the audio file format")
-        sys.exit(-1)
-    logger.info("The sample rate is %d" % sample_rate)
-    assert (sample_rate == 16000)
-    logger.info("The audio file format is right")
+            self.ptq.ptq._convert(self.model)
+            paddle.jit.save(
+                self.model,
+                self.args.export_path,
+                combine_params=True,
+                skip_forward=True)


 def main(config, args):
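The export step above replaces PaddleSlim's `save_quantized_model` with a direct call to the quantizer's private `_convert` followed by `paddle.jit.save` using PaddleSpeech's `combine_params`/`skip_forward` options. For comparison, the stock PaddleSlim dygraph PTQ flow looks roughly like this (a sketch; `model` and the input-spec shape are assumptions, not the script's actual export path):

```python
# Standard PaddleSlim post-training quantization: insert observers, run
# calibration forwards, then convert and save the quantized model.
import paddle
from paddleslim import PTQ

ptq = PTQ()
model.eval()                       # `model` is an assumed dygraph network
quant_model = ptq.quantize(model)  # wrap layers with quant/dequant observers
# ... run forward passes over the calibration utterances here ...
ptq.save_quantized_model(
    quant_model,
    "export.jit.quant",
    input_spec=[paddle.static.InputSpec([None, None, 80], "float32")])
```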
@@ -202,11 +177,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
     parser.add_argument(
-        "--audio_file", type=str, help="path of the input audio file")
+        "--audio_scp", type=str, help="path of the input audio scp file")
+    parser.add_argument(
+        "--num_utts",
+        type=int,
+        default=200,
+        help="num utts for quant calibration.")
     parser.add_argument(
         "--export_path",
         type=str,
-        default='export',
-        help="path of the input audio file")
+        default='export.jit.quant',
+        help="path to save the quantized jit model")
     args = parser.parse_args()

@@ -16,6 +16,7 @@ paddlespeech asr --model conformer_aishell --input ./zh.wav
 paddlespeech asr --model conformer_online_aishell --input ./zh.wav
 paddlespeech asr --model conformer_online_wenetspeech --input ./zh.wav
 paddlespeech asr --model conformer_online_multicn --input ./zh.wav
+paddlespeech asr --model conformer_u2pp_online_wenetspeech --lang zh --input zh.wav
 paddlespeech asr --model transformer_librispeech --lang en --input ./en.wav
 paddlespeech asr --model deepspeech2offline_aishell --input ./zh.wav
 paddlespeech asr --model deepspeech2online_wenetspeech --input ./zh.wav
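The new CLI line also has a Python counterpart via `ASRExecutor`; a sketch assuming the executor accepts the same model/lang options as the CLI flags:

```python
# Roughly equivalent to:
#   paddlespeech asr --model conformer_u2pp_online_wenetspeech --lang zh --input zh.wav
from paddlespeech.cli.asr.infer import ASRExecutor

asr = ASRExecutor()
result = asr(
    audio_file="zh.wav",
    model="conformer_u2pp_online_wenetspeech",
    lang="zh")
print(result)
```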
