[ASR] Remove the transformers import and make variable names consistent with infer.py, test=asr (#2929)

* Remove the transformers import and make variable names consistent with infer.py.

* Add a condition to the ctc_prefix_beam_search decoding path.
pull/2937/head
zxcd 1 year ago committed by GitHub
parent 71bda24437
commit 004a4d6096
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,9 +23,9 @@ from contextlib import nullcontext
import jsonlines
import numpy as np
import paddle
import transformers
from hyperpyyaml import load_hyperpyyaml
from paddle import distributed as dist
from paddlenlp.transformers import AutoTokenizer
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
@ -530,8 +530,7 @@ class Wav2Vec2ASRTrainer(Trainer):
datasets = [train_data, valid_data, test_data]
# Defining tokenizer and loading it
tokenizer = transformers.BertTokenizer.from_pretrained(
'bert-base-chinese')
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
self.tokenizer = tokenizer
# 2. Define audio pipeline:
@data_pipeline.takes("wav")
@ -867,8 +866,7 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size
with jsonlines.open(
self.args.result_file, 'w', encoding='utf8') as fout:
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
if self.use_sb:
metrics = self.sb_compute_metrics(**batch, fout=fout)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from turtle import Turtle
from typing import Dict
from typing import List
from typing import Tuple
@ -83,6 +84,7 @@ class Wav2vec2ASR(nn.Layer):
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int,
tokenizer: str=None,
sb_pipeline=False):
batch_size = feats.shape[0]
@ -93,12 +95,15 @@ class Wav2vec2ASR(nn.Layer):
logger.error(f"current batch_size is {batch_size}")
if decoding_method == 'ctc_greedy_search':
if not sb_pipeline:
if tokenizer is None and sb_pipeline is False:
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
else:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
if sb_pipeline is True:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
else:
hyps = self.ctc_greedy_search(feats)
res = []
res_tokenids = []
for sequence in hyps:
@ -123,13 +128,16 @@ class Wav2vec2ASR(nn.Layer):
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
if not sb_pipeline:
if tokenizer is None and sb_pipeline is False:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
if sb_pipeline is True:
hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = []
res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)

@ -69,7 +69,6 @@ base = [
"paddleslim>=2.3.4",
"paddleaudio>=1.1.0",
"hyperpyyaml",
"transformers",
]
server = ["pattern_singleton", "websockets"]

Loading…
Cancel
Save