[ASR] Remove transformers import and make variable names consistent with infer.py, test=asr (#2929)

* Remove the transformers import and make variable names consistent with infer.py.

* Add a condition for ctc_prefix_beam_search decoding.
pull/2937/head
zxcd 1 year ago committed by GitHub
parent 71bda24437
commit 004a4d6096
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,9 +23,9 @@ from contextlib import nullcontext
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle import paddle
import transformers
from hyperpyyaml import load_hyperpyyaml from hyperpyyaml import load_hyperpyyaml
from paddle import distributed as dist from paddle import distributed as dist
from paddlenlp.transformers import AutoTokenizer
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@ -530,8 +530,7 @@ class Wav2Vec2ASRTrainer(Trainer):
datasets = [train_data, valid_data, test_data] datasets = [train_data, valid_data, test_data]
# Defining tokenizer and loading it # Defining tokenizer and loading it
tokenizer = transformers.BertTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
'bert-base-chinese')
self.tokenizer = tokenizer self.tokenizer = tokenizer
# 2. Define audio pipeline: # 2. Define audio pipeline:
@data_pipeline.takes("wav") @data_pipeline.takes("wav")
@ -867,8 +866,7 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
vocab_list = self.vocab_list vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size decode_batch_size = decode_cfg.decode_batch_size
with jsonlines.open( with jsonlines.open(self.args.result_file, 'w') as fout:
self.args.result_file, 'w', encoding='utf8') as fout:
for i, batch in enumerate(self.test_loader): for i, batch in enumerate(self.test_loader):
if self.use_sb: if self.use_sb:
metrics = self.sb_compute_metrics(**batch, fout=fout) metrics = self.sb_compute_metrics(**batch, fout=fout)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict from collections import defaultdict
from turtle import Turtle
from typing import Dict from typing import Dict
from typing import List from typing import List
from typing import Tuple from typing import Tuple
@ -83,6 +84,7 @@ class Wav2vec2ASR(nn.Layer):
text_feature: Dict[str, int], text_feature: Dict[str, int],
decoding_method: str, decoding_method: str,
beam_size: int, beam_size: int,
tokenizer: str=None,
sb_pipeline=False): sb_pipeline=False):
batch_size = feats.shape[0] batch_size = feats.shape[0]
@ -93,12 +95,15 @@ class Wav2vec2ASR(nn.Layer):
logger.error(f"current batch_size is {batch_size}") logger.error(f"current batch_size is {batch_size}")
if decoding_method == 'ctc_greedy_search': if decoding_method == 'ctc_greedy_search':
if not sb_pipeline: if tokenizer is None and sb_pipeline is False:
hyps = self.ctc_greedy_search(feats) hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps] res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps] res_tokenids = [hyp for hyp in hyps]
else: else:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1)) if sb_pipeline is True:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
else:
hyps = self.ctc_greedy_search(feats)
res = [] res = []
res_tokenids = [] res_tokenids = []
for sequence in hyps: for sequence in hyps:
@ -123,13 +128,16 @@ class Wav2vec2ASR(nn.Layer):
# with other batch decoding mode # with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search': elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1 assert feats.shape[0] == 1
if not sb_pipeline: if tokenizer is None and sb_pipeline is False:
hyp = self.ctc_prefix_beam_search(feats, beam_size) hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)] res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp] res_tokenids = [hyp]
else: else:
hyp = self.ctc_prefix_beam_search( if sb_pipeline is True:
feats.unsqueeze(-1), beam_size) hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [] res = []
res_tokenids = [] res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp) predicted_tokens = text_feature.convert_ids_to_tokens(hyp)

@ -69,7 +69,6 @@ base = [
"paddleslim>=2.3.4", "paddleslim>=2.3.4",
"paddleaudio>=1.1.0", "paddleaudio>=1.1.0",
"hyperpyyaml", "hyperpyyaml",
"transformers",
] ]
server = ["pattern_singleton", "websockets"] server = ["pattern_singleton", "websockets"]

Loading…
Cancel
Save