quant with wav scp

pull/2568/head
Hui Zhang 3 years ago
parent 09a735af24
commit c83c9800cc

@@ -1,7 +1,7 @@
 #!/bin/bash
 
 if [ $# != 4 ];then
-    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
+    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_scp"
     exit -1
 fi
@@ -11,16 +11,15 @@ echo "using $ngpu gpus..."
 config_path=$1
 decode_config_path=$2
 ckpt_prefix=$3
-audio_file=$4
+audio_scp=$4
 
 mkdir -p data
-wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
 if [ $? -ne 0 ]; then
     exit 1
 fi
 
-if [ ! -f ${audio_file} ]; then
-    echo "Plase input the right audio_file path"
+if [ ! -f ${audio_scp} ]; then
+    echo "Please input the right audio_scp path"
     exit 1
 fi
@@ -49,7 +48,7 @@ for type in attention_rescoring; do
         --checkpoint_path ${ckpt_prefix} \
         --opts decode.decoding_method ${type} \
         --opts decode.decode_batch_size ${batch_size} \
-        --audio_file ${audio_file}
+        --audio_scp ${audio_scp}
 
     if [ $? -ne 0 ]; then
         echo "Failed in evaluation!"

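With this change the script takes a Kaldi-style wav.scp, i.e. one "utt_id /abs/path/to/utt.wav" pair per line, instead of a single wav file. As a minimal, hypothetical sketch (the directory data/calib_wavs and the helper name are assumptions, not part of this commit), such a calibration list could be generated like this:

# Hypothetical helper: build a Kaldi-style wav.scp ("utt_id wav_path" per line)
# from a directory of 16 kHz wav files to use as PTQ calibration data.
from pathlib import Path


def make_wav_scp(wav_dir: str, scp_path: str) -> int:
    wavs = sorted(Path(wav_dir).glob("*.wav"))
    with open(scp_path, "w", encoding="utf-8") as f:
        for wav in wavs:
            # Use the file stem as utt_id; adjust to your corpus layout.
            f.write(f"{wav.stem} {wav.resolve()}\n")
    return len(wavs)


if __name__ == "__main__":
    n = make_wav_scp("data/calib_wavs", "data/wav.scp")  # assumed paths
    print(f"wrote {n} entries to data/wav.scp")

The resulting data/wav.scp is what the updated script expects as its fourth argument (audio_scp).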
@@ -20,6 +20,7 @@ import paddle
 import soundfile
 from paddleslim import PTQ
 from yacs.config import CfgNode
+from kaldiio import ReadHelper
 
 from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
@@ -34,7 +35,7 @@ class U2Infer():
     def __init__(self, config, args):
         self.args = args
         self.config = config
-        self.audio_file = args.audio_file
+        self.audio_scp = args.audio_scp
 
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
@@ -63,133 +64,117 @@ class U2Infer():
         self.model.set_state_dict(model_dict)
 
     def run(self):
-        check(args.audio_file)
-
-        with paddle.no_grad():
-            # read
-            audio, sample_rate = soundfile.read(
-                self.audio_file, dtype="int16", always_2d=True)
-            audio = audio[:, 0]
-            logger.info(f"audio shape: {audio.shape}")
-
-            # fbank
-            feat = self.preprocessing(audio, **self.preprocess_args)
-            logger.info(f"feat shape: {feat.shape}")
-
-            ilen = paddle.to_tensor(feat.shape[0])
-            xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
-            decode_config = self.config.decode
-            logger.info(f"decode cfg: {decode_config}")
-            reverse_weight = getattr(decode_config, 'reverse_weight', 0.0)
-            result_transcripts = self.model.decode(
-                xs,
-                ilen,
-                text_feature=self.text_feature,
-                decoding_method=decode_config.decoding_method,
-                beam_size=decode_config.beam_size,
-                ctc_weight=decode_config.ctc_weight,
-                decoding_chunk_size=decode_config.decoding_chunk_size,
-                num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
-                simulate_streaming=decode_config.simulate_streaming,
-                reverse_weight=reverse_weight)
-            rsl = result_transcripts[0][0]
-            utt = Path(self.audio_file).name
-            logger.info(f"hyp: {utt} {rsl}")
-            # print(self.model)
-            # print(self.model.forward_encoder_chunk)
-
-            logger.info("-------------start quant ----------------------")
-            batch_size = 1
-            feat_dim = 80
-            model_size = 512
-            num_left_chunks = -1
-            reverse_weight = 0.3
-            logger.info(
-                f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}"
-            )
-
-            # ######################## self.model.forward_encoder_chunk ############
-            # input_spec = [
-            #     # (T,), int16
-            #     paddle.static.InputSpec(shape=[None], dtype='int16'),
-            # ]
-            # self.model.forward_feature = paddle.jit.to_static(
-            #     self.model.forward_feature, input_spec=input_spec)
-
-            ######################### self.model.forward_encoder_chunk ############
-            input_spec = [
-                # xs, (B, T, D)
-                paddle.static.InputSpec(
-                    shape=[batch_size, None, feat_dim], dtype='float32'),
-                # offset, int, but need be tensor
-                paddle.static.InputSpec(shape=[1], dtype='int32'),
-                # required_cache_size, int
-                num_left_chunks,
-                # att_cache
-                paddle.static.InputSpec(
-                    shape=[None, None, None, None], dtype='float32'),
-                # cnn_cache
-                paddle.static.InputSpec(
-                    shape=[None, None, None, None], dtype='float32')
-            ]
-            self.model.forward_encoder_chunk = paddle.jit.to_static(
-                self.model.forward_encoder_chunk, input_spec=input_spec)
-
-            ######################### self.model.ctc_activation ########################
-            input_spec = [
-                # encoder_out, (B,T,D)
-                paddle.static.InputSpec(
-                    shape=[batch_size, None, model_size], dtype='float32')
-            ]
-            self.model.ctc_activation = paddle.jit.to_static(
-                self.model.ctc_activation, input_spec=input_spec)
-
-            ######################### self.model.forward_attention_decoder ########################
-            input_spec = [
-                # hyps, (B, U)
-                paddle.static.InputSpec(shape=[None, None], dtype='int64'),
-                # hyps_lens, (B,)
-                paddle.static.InputSpec(shape=[None], dtype='int64'),
-                # encoder_out, (B,T,D)
-                paddle.static.InputSpec(
-                    shape=[batch_size, None, model_size], dtype='float32'),
-                reverse_weight
-            ]
-            self.model.forward_attention_decoder = paddle.jit.to_static(
-                self.model.forward_attention_decoder, input_spec=input_spec)
-            ################################################################################
-
-            # jit save
-            logger.info(f"export save: {self.args.export_path}")
-            config = {
-                'is_static': True,
-                'combine_params': True,
-                'skip_forward': True
-            }
-            self.ptq.save_quantized_model(self.model, self.args.export_path)
-            # paddle.jit.save(
-            #     self.model,
-            #     self.args.export_path,
-            #     combine_params=True,
-            #     skip_forward=True)
-
-
-def check(audio_file):
-    if not os.path.isfile(audio_file):
-        print("Please input the right audio file path")
-        sys.exit(-1)
-
-    logger.info("checking the audio file format......")
-    try:
-        sig, sample_rate = soundfile.read(audio_file)
-    except Exception as e:
-        logger.error(str(e))
-        logger.error(
-            "can not open the wav file, please check the audio file format")
-        sys.exit(-1)
-    logger.info("The sample rate is %d" % sample_rate)
-    assert (sample_rate == 16000)
-    logger.info("The audio file format is right")
+        cnt = 0
+        with ReadHelper(f"scp:{self.audio_scp}") as reader:
+            for key, (rate, audio) in reader:
+                assert rate == 16000
+                cnt += 1
+                if cnt > args.num_utts:
+                    break
+
+                with paddle.no_grad():
+                    logger.info(f"audio shape: {audio.shape}")
+
+                    # fbank
+                    feat = self.preprocessing(audio, **self.preprocess_args)
+                    logger.info(f"feat shape: {feat.shape}")
+
+                    ilen = paddle.to_tensor(feat.shape[0])
+                    xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
+                    decode_config = self.config.decode
+                    logger.info(f"decode cfg: {decode_config}")
+                    result_transcripts = self.model.decode(
+                        xs,
+                        ilen,
+                        text_feature=self.text_feature,
+                        decoding_method=decode_config.decoding_method,
+                        beam_size=decode_config.beam_size,
+                        ctc_weight=decode_config.ctc_weight,
+                        decoding_chunk_size=decode_config.decoding_chunk_size,
+                        num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
+                        simulate_streaming=decode_config.simulate_streaming,
+                        reverse_weight=decode_config.reverse_weight)
+                    rsl = result_transcripts[0][0]
+                    utt = key
+                    logger.info(f"hyp: {utt} {rsl}")
+                    # print(self.model)
+                    # print(self.model.forward_encoder_chunk)
+
+        logger.info("-------------start quant ----------------------")
+        batch_size = 1
+        feat_dim = 80
+        model_size = 512
+        num_left_chunks = -1
+        reverse_weight = 0.3
+        logger.info(
+            f"U2 Export Model Params: batch_size {batch_size}, feat_dim {feat_dim}, model_size {model_size}, num_left_chunks {num_left_chunks}, reverse_weight {reverse_weight}"
+        )
+
+        # ######################## self.model.forward_encoder_chunk ############
+        # input_spec = [
+        #     # (T,), int16
+        #     paddle.static.InputSpec(shape=[None], dtype='int16'),
+        # ]
+        # self.model.forward_feature = paddle.jit.to_static(
+        #     self.model.forward_feature, input_spec=input_spec)
+
+        ######################### self.model.forward_encoder_chunk ############
+        input_spec = [
+            # xs, (B, T, D)
+            paddle.static.InputSpec(
+                shape=[batch_size, None, feat_dim], dtype='float32'),
+            # offset, int, but need be tensor
+            paddle.static.InputSpec(shape=[1], dtype='int32'),
+            # required_cache_size, int
+            num_left_chunks,
+            # att_cache
+            paddle.static.InputSpec(
+                shape=[None, None, None, None], dtype='float32'),
+            # cnn_cache
+            paddle.static.InputSpec(
+                shape=[None, None, None, None], dtype='float32')
+        ]
+        self.model.forward_encoder_chunk = paddle.jit.to_static(
+            self.model.forward_encoder_chunk, input_spec=input_spec)
+
+        ######################### self.model.ctc_activation ########################
+        input_spec = [
+            # encoder_out, (B,T,D)
+            paddle.static.InputSpec(
+                shape=[batch_size, None, model_size], dtype='float32')
+        ]
+        self.model.ctc_activation = paddle.jit.to_static(
+            self.model.ctc_activation, input_spec=input_spec)
+
+        ######################### self.model.forward_attention_decoder ########################
+        input_spec = [
+            # hyps, (B, U)
+            paddle.static.InputSpec(shape=[None, None], dtype='int64'),
+            # hyps_lens, (B,)
+            paddle.static.InputSpec(shape=[None], dtype='int64'),
+            # encoder_out, (B,T,D)
+            paddle.static.InputSpec(
+                shape=[batch_size, None, model_size], dtype='float32'),
+            reverse_weight
+        ]
+        self.model.forward_attention_decoder = paddle.jit.to_static(
+            self.model.forward_attention_decoder, input_spec=input_spec)
+        ################################################################################
+
+        # jit save
+        logger.info(f"export save: {self.args.export_path}")
+        config = {
+            'is_static': True,
+            'combine_params': True,
+            'skip_forward': True
+        }
+        self.ptq.save_quantized_model(self.model, self.args.export_path)
+        # paddle.jit.save(
+        #     self.model,
+        #     self.args.export_path,
+        #     combine_params=True,
+        #     skip_forward=True)
 
 
 def main(config, args):
@@ -202,7 +187,9 @@ if __name__ == "__main__":
     parser.add_argument(
         "--result_file", type=str, help="path of save the asr result")
     parser.add_argument(
-        "--audio_file", type=str, help="path of the input audio file")
+        "--audio_scp", type=str, help="path of the input audio scp file")
+    parser.add_argument(
+        "--num_utts", type=int, default=200, help="num utts for quant calibration.")
     parser.add_argument(
         "--export_path",
         type=str,
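For context, kaldiio's ReadHelper opened with an "scp:" specifier yields (utt_id, (sample_rate, samples)) pairs for each wav.scp entry, which is what the new run() loop consumes for PTQ calibration. A standalone sketch of that loop, assuming a hypothetical data/wav.scp of 16 kHz utterances and the --num_utts default of 200 added in this commit:

# Minimal sketch of the calibration read loop, assuming kaldiio is installed
# and "data/wav.scp" (a hypothetical path) lists 16 kHz mono wav files.
from kaldiio import ReadHelper

NUM_UTTS = 200  # mirrors the --num_utts default added in this commit


def iter_calibration_utts(scp_path: str, num_utts: int = NUM_UTTS):
    cnt = 0
    with ReadHelper(f"scp:{scp_path}") as reader:
        # ReadHelper yields (utt_id, (sample_rate, samples)) per scp entry.
        for key, (rate, audio) in reader:
            assert rate == 16000, f"{key}: expected 16 kHz, got {rate}"
            cnt += 1
            if cnt > num_utts:
                break
            # Each utterance would be featurized and run through the model
            # so PaddleSlim PTQ can collect activation statistics.
            yield key, audio


if __name__ == "__main__":
    for key, audio in iter_calibration_utts("data/wav.scp"):
        print(key, audio.shape)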
