diff --git a/examples/wenetspeech/asr1/local/quant.sh b/examples/wenetspeech/asr1/local/quant.sh index 6a2a4c72b..ac854aaad 100755 --- a/examples/wenetspeech/asr1/local/quant.sh +++ b/examples/wenetspeech/asr1/local/quant.sh @@ -1,5 +1,6 @@ #!/bin/bash +# ./local/quant.sh conf/chunk_conformer_u2pp.yaml conf/tuning/chunk_decode.yaml exp/chunk_conformer_u2pp/checkpoints/avg_10 data/wav.aishell.test.scp if [ $# != 4 ];then echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_scp" exit -1 @@ -48,6 +49,7 @@ for type in attention_rescoring; do --checkpoint_path ${ckpt_prefix} \ --opts decode.decoding_method ${type} \ --opts decode.decode_batch_size ${batch_size} \ + --num_utts 200 \ --audio_scp ${audio_scp} if [ $? -ne 0 ]; then diff --git a/paddlespeech/s2t/exps/u2/bin/quant.py b/paddlespeech/s2t/exps/u2/bin/quant.py index 2f17dc252..71101e1c4 100644 --- a/paddlespeech/s2t/exps/u2/bin/quant.py +++ b/paddlespeech/s2t/exps/u2/bin/quant.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Evaluation for U2 model.""" +"""Quantzation U2 model.""" import paddle from kaldiio import ReadHelper from paddleslim import PTQ @@ -159,17 +159,12 @@ class U2Infer(): # jit save logger.info(f"export save: {self.args.export_path}") - config = { - 'is_static': True, - 'combine_params': True, - 'skip_forward': True - } - self.ptq.save_quantized_model(self.model, self.args.export_path) - # paddle.jit.save( - # self.model, - # self.args.export_path, - # combine_params=True, - # skip_forward=True) + self.ptq.save_quantized_model( + self.model, + self.args.export_path, + postprocess=False, + combine_params=True, + skip_forward=True) def main(config, args): @@ -191,7 +186,7 @@ if __name__ == "__main__": parser.add_argument( "--export_path", type=str, - default='export', + default='export.jit.quant', help="path of the input audio file") args = parser.parse_args() diff --git a/setup.py b/setup.py index 35668bddb..c1757b194 100644 --- a/setup.py +++ b/setup.py @@ -75,7 +75,7 @@ base = [ "braceexpand", "pyyaml", "pybind11", - "paddleslim==2.3.4", + "paddleslim==2.4.0", ] server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]