From ee894330a8b914b145715087f6ec39cffd4d2098 Mon Sep 17 00:00:00 2001 From: zxcd <228587199@qq.com> Date: Wed, 15 Feb 2023 09:29:05 +0000 Subject: [PATCH] rm transformers import and modify variable name consistent with infer.py --- paddlespeech/s2t/exps/wav2vec2/model.py | 8 +++----- .../s2t/models/wav2vec2/wav2vec2_ASR.py | 18 +++++++++++++----- setup.py | 1 - 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py index 878c0d84b..86b56b876 100644 --- a/paddlespeech/s2t/exps/wav2vec2/model.py +++ b/paddlespeech/s2t/exps/wav2vec2/model.py @@ -23,9 +23,9 @@ from contextlib import nullcontext import jsonlines import numpy as np import paddle -import transformers from hyperpyyaml import load_hyperpyyaml from paddle import distributed as dist +from paddlenlp.transformers import AutoTokenizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.io.dataloader import DataLoaderFactory @@ -530,8 +530,7 @@ class Wav2Vec2ASRTrainer(Trainer): datasets = [train_data, valid_data, test_data] # Defining tokenizer and loading it - tokenizer = transformers.BertTokenizer.from_pretrained( - 'bert-base-chinese') + tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese') self.tokenizer = tokenizer # 2. 
Define audio pipeline: @data_pipeline.takes("wav") @@ -867,8 +866,7 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer): vocab_list = self.vocab_list decode_batch_size = decode_cfg.decode_batch_size - with jsonlines.open( - self.args.result_file, 'w', encoding='utf8') as fout: + with jsonlines.open(self.args.result_file, 'w') as fout: for i, batch in enumerate(self.test_loader): if self.use_sb: metrics = self.sb_compute_metrics(**batch, fout=fout) diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py index f91a41c32..059c0d909 100755 --- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py +++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py @@ -83,6 +84,7 @@ class Wav2vec2ASR(nn.Layer): text_feature: Dict[str, int], decoding_method: str, beam_size: int, + tokenizer: str=None, sb_pipeline=False): batch_size = feats.shape[0] @@ -93,12 +95,15 @@ logger.error(f"current batch_size is {batch_size}") if decoding_method == 'ctc_greedy_search': - if not sb_pipeline: + if tokenizer is None and sb_pipeline is False: hyps = self.ctc_greedy_search(feats) res = [text_feature.defeaturize(hyp) for hyp in hyps] res_tokenids = [hyp for hyp in hyps] else: - hyps = self.ctc_greedy_search(feats.unsqueeze(-1)) + if sb_pipeline is True: + hyps = self.ctc_greedy_search(feats.unsqueeze(-1)) + else: + hyps = self.ctc_greedy_search(feats) res = [] res_tokenids = [] for sequence in hyps: @@ -123,13 +128,16 @@ # with other batch decoding mode elif decoding_method == 'ctc_prefix_beam_search': assert feats.shape[0] == 1 - if not sb_pipeline: + if tokenizer is None: hyp = self.ctc_prefix_beam_search(feats, beam_size) res = 
[text_feature.defeaturize(hyp)] res_tokenids = [hyp] else: - hyp = self.ctc_prefix_beam_search( - feats.unsqueeze(-1), beam_size) + if sb_pipeline is True: + hyp = self.ctc_prefix_beam_search( + feats.unsqueeze(-1), beam_size) + else: + hyp = self.ctc_prefix_beam_search(feats, beam_size) res = [] res_tokenids = [] predicted_tokens = text_feature.convert_ids_to_tokens(hyp) diff --git a/setup.py b/setup.py index 014b0ed91..69739b3b8 100644 --- a/setup.py +++ b/setup.py @@ -69,7 +69,6 @@ base = [ "paddleslim>=2.3.4", "paddleaudio>=1.1.0", "hyperpyyaml", - "transformers", ] server = ["pattern_singleton", "websockets"]