using paddleslim 2.4 api

pull/2568/head
Hui Zhang 3 years ago
parent 6f59642efa
commit 2389ed6675

@ -1,5 +1,6 @@
#!/bin/bash #!/bin/bash
# ./local/quant.sh conf/chunk_conformer_u2pp.yaml conf/tuning/chunk_decode.yaml exp/chunk_conformer_u2pp/checkpoints/avg_10 data/wav.aishell.test.scp
if [ $# != 4 ];then if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_scp" echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_scp"
exit -1 exit -1
@ -48,6 +49,7 @@ for type in attention_rescoring; do
--checkpoint_path ${ckpt_prefix} \ --checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \ --opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size} \ --opts decode.decode_batch_size ${batch_size} \
--num_utts 200 \
--audio_scp ${audio_scp} --audio_scp ${audio_scp}
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Evaluation for U2 model.""" """Quantzation U2 model."""
import paddle import paddle
from kaldiio import ReadHelper from kaldiio import ReadHelper
from paddleslim import PTQ from paddleslim import PTQ
@ -159,17 +159,12 @@ class U2Infer():
# jit save # jit save
logger.info(f"export save: {self.args.export_path}") logger.info(f"export save: {self.args.export_path}")
config = { self.ptq.save_quantized_model(
'is_static': True, self.model,
'combine_params': True, self.args.export_path,
'skip_forward': True postprocess=False,
} combine_params=True,
self.ptq.save_quantized_model(self.model, self.args.export_path) skip_forward=True)
# paddle.jit.save(
# self.model,
# self.args.export_path,
# combine_params=True,
# skip_forward=True)
def main(config, args): def main(config, args):
@ -191,7 +186,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--export_path", "--export_path",
type=str, type=str,
default='export', default='export.jit.quant',
help="path of the input audio file") help="path of the input audio file")
args = parser.parse_args() args = parser.parse_args()

@ -75,7 +75,7 @@ base = [
"braceexpand", "braceexpand",
"pyyaml", "pyyaml",
"pybind11", "pybind11",
"paddleslim==2.3.4", "paddleslim==2.4.0",
] ]
server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"] server = ["fastapi", "uvicorn", "pattern_singleton", "websockets"]

Loading…
Cancel
Save