add collactor and evaluation code for ST

pull/743/head
Junkun 3 years ago
parent 0323151912
commit ac0ae57ef2

@ -24,7 +24,6 @@ from typing import Tuple
import numpy as np
import paddle
import sacrebleu
from paddle import distributed as dist
from paddle.io import DataLoader
from yacs.config import CfgNode
@ -32,6 +31,7 @@ from yacs.config import CfgNode
from deepspeech.io.collator_st import KaldiPrePorocessedCollator
from deepspeech.io.collator_st import SpeechCollator
from deepspeech.io.collator_st import TripletKaldiPrePorocessedCollator
from deepspeech.io.collator_st import TripletSpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.io.dataset import TripletManifestDataset
from deepspeech.io.sampler import SortagradBatchSampler
@ -40,6 +40,7 @@ from deepspeech.models.u2_st import U2STModel
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer
from deepspeech.utils import bleu_score
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
@ -248,7 +249,11 @@ class U2STTrainer(Trainer):
dev_dataset = Dataset.from_config(config)
if config.collator.raw_wav:
TestCollator = Collator = SpeechCollator
if config.model.model_conf.asr_weight > 0.:
Collator = TripletSpeechCollator
TestCollator = SpeechCollator
else:
TestCollator = Collator = SpeechCollator
# Not yet implement the mtl loader for raw_wav.
else:
if config.model.model_conf.asr_weight > 0.:
@ -393,7 +398,7 @@ class U2STTester(U2STTrainer):
lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm', # Filepath for language model.
decoding_method='attention', # Decoding method. Options: 'attention', 'ctc_greedy_search',
# 'ctc_prefix_beam_search', 'attention_rescoring'
error_rate_type='wer', # Error rate type for evaluation. Options `wer`, 'cer'
error_rate_type='bleu', # Error rate type for evaluation. Options `bleu`, 'char_bleu'
num_proc_bsearch=8, # # of CPUs for beam search.
beam_size=10, # Beam search width.
batch_size=16, # decoding batch size
@ -428,10 +433,10 @@ class U2STTester(U2STTrainer):
audio_len,
texts,
texts_len,
bleu_func,
fout=None):
cfg = self.config.decoding
len_refs, num_ins = 0, 0
bleu_func = sacrebleu.corpus_bleu
start_time = time.time()
text_feature = self.test_loader.collate_fn.text_feature
@ -487,6 +492,9 @@ class U2STTester(U2STTrainer):
self.model.eval()
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
cfg = self.config.decoding
bleu_func = bleu_score.char_bleu if cfg.error_rate_type == 'char-bleu' else bleu_score.bleu
stride_ms = self.test_loader.collate_fn.stride_ms
hyps, refs = [], []
len_refs, num_ins = 0, 0
@ -495,7 +503,7 @@ class U2STTester(U2STTrainer):
with open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_translation_metrics(
*batch, fout=fout)
*batch, bleu_func=bleu_func, fout=fout)
hyps += metrics['hyps']
refs += metrics['refs']
bleu = metrics['bleu']
@ -504,19 +512,16 @@ class U2STTester(U2STTrainer):
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, BELU (%d) = %f" %
(rtf, num_ins, bleu))
logger.info("RTF: %f, BELU (%d) = %f" % (rtf, num_ins, bleu))
rtf = num_time / (num_frames * stride_ms)
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "RTF: {}, ".format(rtf)
msg += "Test set [%s]: %s" % (
len(hyps), str(sacrebleu.corpus_bleu(hyps, [refs])))
msg += "Test set [%s]: %s" % (len(hyps), str(bleu_func(hyps, [refs])))
logger.info(msg)
bleu_meta_path = os.path.splitext(
self.args.result_file)[0] + '.bleu'
bleu_meta_path = os.path.splitext(self.args.result_file)[0] + '.bleu'
err_type_str = "BLEU"
with open(bleu_meta_path, 'w') as f:
data = json.dumps({
@ -527,7 +532,7 @@ class U2STTester(U2STTrainer):
"rtf":
rtf,
err_type_str:
sacrebleu.corpus_bleu(hyps, [refs]).score,
bleu_func(hyps, [refs]).score,
"dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,

@ -0,0 +1,666 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import namedtuple
from typing import Optional
from typing import Tuple
import kaldiio
import numpy as np
from yacs.config import CfgNode
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log
__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]
logger = Log(__name__).getlog()
# namedtupe need global for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])
class SpeechCollator():
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
mean_std_filepath="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0, # feature dither
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'mean_std_filepath' in config.collator
assert 'vocab_filepath' in config.collator
assert 'specgram_type' in config.collator
assert 'n_fft' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
mean_std_filepath=config.collator.mean_std_filepath,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
specgram_type=config.collator.specgram_type,
feat_dim=config.collator.feat_dim,
delta_delta=config.collator.delta_delta,
stride_ms=config.collator.stride_ms,
window_ms=config.collator.window_ms,
n_fft=config.collator.n_fft,
max_freq=config.collator.max_freq,
target_sample_rate=config.collator.target_sample_rate,
use_dB_normalization=config.collator.use_dB_normalization,
target_dB=config.collator.target_dB,
dither=config.collator.dither,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
def __init__(
self,
aug_file,
mean_std_filepath,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
specgram_type='linear', # 'linear', 'mfcc', 'fbank'
feat_dim=0, # 'mfcc', 'fbank'
delta_delta=False, # 'mfcc', 'fbank'
stride_ms=10.0, # ms
window_ms=20.0, # ms
n_fft=None, # fft points
max_freq=None, # None for samplerate/2
target_sample_rate=16000, # target sample rate
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
keep_transcription_text=True):
"""SpeechCollator Collator
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
mean_std_filepath (str): mean and std file path, which suffix is *.npy
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
stride_ms (float, optional): stride size in ms. Defaults to 10.0.
window_ms (float, optional): window size in ms. Defaults to 20.0.
n_fft (int, optional): fft points for rfft. Defaults to None.
max_freq (int, optional): max cut freq. Defaults to None.
target_sample_rate (int, optional): target sample rate which used for training. Defaults to 16000.
specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
feat_dim (int, optional): audio feature dim, using by 'mfcc' or 'fbank'. Defaults to None.
delta_delta (bool, optional): audio feature with delta-delta, using by 'fbank' or 'mfcc'. Defaults to False.
use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
target_dB (int, optional): target dB. Defaults to -20.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
if ``keep_transcription_text`` is False, text is token ids else is raw string.
Do augmentations
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None
self._stride_ms = stride_ms
self._target_sample_rate = target_sample_rate
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
def _parse_tar(self, file):
"""Parse a tar file to get a tarfile object
and a map containing tarinfoes
"""
result = {}
f = tarfile.open(file)
for tarinfo in f.getmembers():
result[tarinfo.name] = tarinfo
return f, result
def _subfile_from_tar(self, file):
"""Get subfile object from tar.
It will return a subfile object from tar file
and cached tar file info for next reading request.
"""
tarpath, filename = file.split(':', 1)[1].split('#', 1)
if 'tar2info' not in self._local_data.__dict__:
self._local_data.tar2info = {}
if 'tar2object' not in self._local_data.__dict__:
self._local_data.tar2object = {}
if tarpath not in self._local_data.tar2info:
object, infoes = self._parse_tar(tarpath)
self._local_data.tar2info[tarpath] = infoes
self._local_data.tar2object[tarpath] = object
return self._local_data.tar2object[tarpath].extractfile(
self._local_data.tar2info[tarpath][filename])
def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), translation)
else:
speech_segment = SpeechSegment.from_file(audio_file, translation)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, translation_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part
def __call__(self, batch):
"""batch examples
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
text (List[int] or str): shape (U,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
audio_lens = []
texts = []
text_lens = []
utts = []
for utt, audio, text in batch:
audio, text = self.process_utterance(audio, text)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = []
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens = [ord(t) for t in text]
else:
tokens = text # token ids
tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
tokens, dtype=np.int64)
texts.append(tokens)
text_lens.append(tokens.shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_texts = pad_sequence(
texts, padding_value=IGNORE_ID).astype(np.int64)
text_lens = np.array(text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, padded_texts, text_lens
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._speech_featurizer.vocab_size
@property
def vocab_list(self):
return self._speech_featurizer.vocab_list
@property
def vocab_dict(self):
return self._speech_featurizer.vocab_dict
@property
def text_feature(self):
return self._speech_featurizer.text_feature
@property
def feature_size(self):
return self._speech_featurizer.feature_size
@property
def stride_ms(self):
return self._speech_featurizer.stride_ms
class TripletSpeechCollator(SpeechCollator):
def process_utterance(self, audio_file, translation, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of audio file.
:type audio_file: str | file
:param translation: translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
if isinstance(audio_file, str) and audio_file.startswith('tar:'):
speech_segment = SpeechSegment.from_file(
self._subfile_from_tar(audio_file), translation)
else:
speech_segment = SpeechSegment.from_file(audio_file, translation)
# audio augment
self._augmentation_pipeline.transform_audio(speech_segment)
specgram, translation_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
transcript_part = self._speech_featurizer._text_featurizer.featurize(
transcript)
if self._normalizer:
specgram = self._normalizer.apply(specgram)
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
return specgram, translation_part, transcript_part
def __call__(self, batch):
"""batch examples
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
text (List[int] or str): shape (U,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
text : (B, Umax)
text_lens: (B)
"""
audios = []
audio_lens = []
translation_text = []
translation_text_lens = []
transcription_text = []
transcription_text_lens = []
utts = []
for utt, audio, translation, transcription in batch:
audio, translation, transcription = self.process_utterance(
audio, translation, transcription)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = [[], []]
for idx, text in enumerate([translation, transcription]):
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens[idx] = [ord(t) for t in text]
else:
tokens[idx] = text # token ids
tokens[idx] = tokens[idx] if isinstance(
tokens[idx], np.ndarray) else np.array(
tokens[idx], dtype=np.int64)
translation_text.append(tokens[0])
translation_text_lens.append(tokens[0].shape[0])
transcription_text.append(tokens[1])
transcription_text_lens.append(tokens[1].shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_translation = pad_sequence(
translation_text, padding_value=IGNORE_ID).astype(np.int64)
translation_lens = np.array(translation_text_lens).astype(np.int64)
padded_transcription = pad_sequence(
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, (
padded_translation, padded_transcription), (translation_lens,
transcription_lens)
class KaldiPrePorocessedCollator(SpeechCollator):
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
default = CfgNode(
dict(
augmentation_config="",
random_seed=0,
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
feat_dim=0,
stride_ms=10.0,
keep_transcription_text=False))
if config is not None:
config.merge_from_other_cfg(default)
return default
@classmethod
def from_config(cls, config):
"""Build a SpeechCollator object from a config.
Args:
config (yacs.config.CfgNode): configs object.
Returns:
SpeechCollator: collator object.
"""
assert 'augmentation_config' in config.collator
assert 'keep_transcription_text' in config.collator
assert 'vocab_filepath' in config.collator
assert config.collator
if isinstance(config.collator.augmentation_config, (str, bytes)):
if config.collator.augmentation_config:
aug_file = io.open(
config.collator.augmentation_config,
mode='r',
encoding='utf8')
else:
aug_file = io.StringIO(initial_value='{}', newline='')
else:
aug_file = config.collator.augmentation_config
assert isinstance(aug_file, io.StringIO)
speech_collator = cls(
aug_file=aug_file,
random_seed=0,
unit_type=config.collator.unit_type,
vocab_filepath=config.collator.vocab_filepath,
spm_model_prefix=config.collator.spm_model_prefix,
feat_dim=config.collator.feat_dim,
stride_ms=config.collator.stride_ms,
keep_transcription_text=config.collator.keep_transcription_text)
return speech_collator
def __init__(self,
aug_file,
vocab_filepath,
spm_model_prefix,
random_seed=0,
unit_type="char",
feat_dim=0,
stride_ms=10.0,
keep_transcription_text=True):
"""SpeechCollator Collator
Args:
unit_type(str): token unit type, e.g. char, word, spm
vocab_filepath (str): vocab file path.
spm_model_prefix (str): spm model prefix, need if `unit_type` is spm.
augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
random_seed (int, optional): for random generator. Defaults to 0.
keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False.
if ``keep_transcription_text`` is False, text is token ids else is raw string.
Do augmentations
Padding audio features with zeros to make them have the same shape (or
a user-defined shape) within one batch.
"""
self._keep_transcription_text = keep_transcription_text
self._feat_dim = feat_dim
self._stride_ms = stride_ms
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix)
def process_utterance(self, audio_file, translation):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of kaldi processed feature.
:type audio_file: str | file
:param translation: Translation text.
:type translation: str
:return: Tuple of audio feature tensor and data of translation part,
where translation part could be token ids or text.
:rtype: tuple of (2darray, list)
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation
else:
text_ids = self._text_featurizer.featurize(translation)
return specgram, text_ids
@property
def manifest(self):
return self._manifest
@property
def vocab_size(self):
return self._text_featurizer.vocab_size
@property
def vocab_list(self):
return self._text_featurizer.vocab_list
@property
def vocab_dict(self):
return self._text_featurizer.vocab_dict
@property
def text_feature(self):
return self._text_featurizer
@property
def feature_size(self):
return self._feat_dim
@property
def stride_ms(self):
return self._stride_ms
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
def process_utterance(self, audio_file, translation, transcript):
"""Load, augment, featurize and normalize for speech data.
:param audio_file: Filepath or file object of kali processed feature.
:type audio_file: str | file
:param translation: Translation text.
:type translation: str
:param transcript: Transcription text.
:type transcript: str
:return: Tuple of audio feature tensor and data of translation and transcription parts,
where translation and transcription parts could be token ids or text.
:rtype: tuple of (2darray, (list, list))
"""
specgram = kaldiio.load_mat(audio_file)
specgram = specgram.transpose([1, 0])
assert specgram.shape[
0] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
self._feat_dim, specgram.shape[0])
# specgram augment
specgram = self._augmentation_pipeline.transform_feature(specgram)
specgram = specgram.transpose([1, 0])
if self._keep_transcription_text:
return specgram, translation, transcript
else:
translation_text_ids = self._text_featurizer.featurize(translation)
transcript_text_ids = self._text_featurizer.featurize(transcript)
return specgram, translation_text_ids, transcript_text_ids
def __call__(self, batch):
"""batch examples
Args:
batch ([List]): batch is (audio, text)
audio (np.ndarray) shape (D, T)
translation (List[int] or str): shape (U,)
transcription (List[int] or str): shape (V,)
Returns:
tuple(audio, text, audio_lens, text_lens): batched data.
audio : (B, Tmax, D)
audio_lens: (B)
translation_text : (B, Umax)
translation_text_lens: (B)
transcription_text : (B, Vmax)
transcription_text_lens: (B)
"""
audios = []
audio_lens = []
translation_text = []
translation_text_lens = []
transcription_text = []
transcription_text_lens = []
utts = []
for utt, audio, translation, transcription in batch:
audio, translation, transcription = self.process_utterance(
audio, translation, transcription)
#utt
utts.append(utt)
# audio
audios.append(audio) # [T, D]
audio_lens.append(audio.shape[0])
# text
# for training, text is token ids
# else text is string, convert to unicode ord
tokens = [[], []]
for idx, text in enumerate([translation, transcription]):
if self._keep_transcription_text:
assert isinstance(text, str), (type(text), text)
tokens[idx] = [ord(t) for t in text]
else:
tokens[idx] = text # token ids
tokens[idx] = tokens[idx] if isinstance(
tokens[idx], np.ndarray) else np.array(
tokens[idx], dtype=np.int64)
translation_text.append(tokens[0])
translation_text_lens.append(tokens[0].shape[0])
transcription_text.append(tokens[1])
transcription_text_lens.append(tokens[1].shape[0])
padded_audios = pad_sequence(
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
audio_lens = np.array(audio_lens).astype(np.int64)
padded_translation = pad_sequence(
translation_text, padding_value=IGNORE_ID).astype(np.int64)
translation_lens = np.array(translation_text_lens).astype(np.int64)
padded_transcription = pad_sequence(
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
return utts, padded_audios, audio_lens, (
padded_translation, padded_transcription), (translation_lens,
transcription_lens)

@ -19,9 +19,7 @@ from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = [
"ManifestDataset",
]
__all__ = ["ManifestDataset", "TripletManifestDataset"]
logger = Log(__name__).getlog()
@ -105,3 +103,16 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
return instance["utt"], instance["feat"], instance["text"]
class TripletManifestDataset(ManifestDataset):
"""
For Joint Training of Speech Translation and ASR.
text: translation,
text1: transcript.
"""
def __getitem__(self, idx):
instance = self._manifest[idx]
return instance["utt"], instance["feat"], instance["text"], instance[
"text1"]

@ -0,0 +1,734 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""U2 ASR Model
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf)
"""
import sys
import time
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.frontend.utility import load_cmvn
from deepspeech.modules.cmvn import GlobalCMVN
from deepspeech.modules.ctc import CTCDecoder
from deepspeech.modules.decoder import TransformerDecoder
from deepspeech.modules.encoder import ConformerEncoder
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.loss import LabelSmoothingLoss
from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds
from deepspeech.modules.mask import mask_finished_scores
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils import checkpoint
from deepspeech.utils import layer_tools
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add
__all__ = ["U2STModel", "U2STInferModel"]
logger = Log(__name__).getlog()
class U2STBaseModel(nn.Module):
"""CTC-Attention hybrid Encoder-Decoder model"""
@classmethod
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
# network architecture
default = CfgNode()
# allow add new item when merge_with_file
default.cmvn_file = ""
default.cmvn_file_type = "json"
default.input_dim = 0
default.output_dim = 0
# encoder related
default.encoder = 'transformer'
default.encoder_conf = CfgNode(
dict(
output_size=256, # dimension of attention
attention_heads=4,
linear_units=2048, # the number of units of position-wise feed forward
num_blocks=12, # the number of encoder blocks
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer='conv2d', # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before=True,
# use_cnn_module=True,
# cnn_module_kernel=15,
# activation_type='swish',
# pos_enc_layer_type='rel_pos',
# selfattention_layer_type='rel_selfattn',
))
# decoder related
default.decoder = 'transformer'
default.decoder_conf = CfgNode(
dict(
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0, ))
# hybrid CTC/attention
default.model_conf = CfgNode(
dict(
asr_weight=0.0,
ctc_weight=0.0,
lsm_weight=0.1, # label smoothing option
length_normalized_loss=False, ))
if config is not None:
config.merge_from_other_cfg(default)
return default
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,
st_decoder: TransformerDecoder,
decoder: TransformerDecoder=None,
ctc: CTCDecoder=None,
ctc_weight: float=0.0,
asr_weight: float=0.0,
ignore_id: int=IGNORE_ID,
lsm_weight: float=0.0,
length_normalized_loss: bool=False):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.asr_weight = asr_weight
self.encoder = encoder
self.st_decoder = st_decoder
self.decoder = decoder
self.ctc = ctc
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss, )
def forward(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
asr_text: paddle.Tensor=None,
asr_text_lengths: paddle.Tensor=None,
) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor], Optional[
paddle.Tensor]]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
Returns:
total_loss, attention_loss, ctc_loss
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
text.shape, text_lengths.shape)
# 1. Encoder
start = time.time()
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_time = time.time() - start
#logger.debug(f"encoder time: {encoder_time}")
#TODO(Hui Zhang): sum not support bool type
#encoder_out_lens = encoder_mask.squeeze(1).sum(1) #[B, 1, T] -> [B]
encoder_out_lens = encoder_mask.squeeze(1).cast(paddle.int64).sum(
1) #[B, 1, T] -> [B]
# 2a. ST-decoder branch
start = time.time()
loss_st, acc_st = self._calc_st_loss(encoder_out, encoder_mask, text,
text_lengths)
decoder_time = time.time() - start
loss_asr_att = None
loss_asr_ctc = None
# 2b. ASR Attention-decoder branch
if self.asr_weight > 0.:
if self.ctc_weight != 1.0:
start = time.time()
loss_asr_att, acc_att = self._calc_att_loss(
encoder_out, encoder_mask, asr_text, asr_text_lengths)
decoder_time = time.time() - start
# 2c. CTC branch
if self.ctc_weight != 0.0:
start = time.time()
loss_asr_ctc = self.ctc(encoder_out, encoder_out_lens, asr_text,
asr_text_lengths)
ctc_time = time.time() - start
if loss_asr_ctc is None:
loss_asr = loss_asr_att
elif loss_asr_att is None:
loss_asr = loss_asr_ctc
else:
loss_asr = self.ctc_weight * loss_asr_ctc + (1 - self.ctc_weight
) * loss_asr_att
loss = self.asr_weight * loss_asr + (1 - self.asr_weight) * loss_st
else:
loss = loss_st
return loss, loss_st, loss_asr_att, loss_asr_ctc
def _calc_st_loss(
self,
encoder_out: paddle.Tensor,
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss.
Args:
encoder_out (paddle.Tensor): [B, Tmax, D]
encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B]
Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
def _calc_att_loss(
self,
encoder_out: paddle.Tensor,
encoder_mask: paddle.Tensor,
ys_pad: paddle.Tensor,
ys_pad_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, float]:
"""Calc attention loss.
Args:
encoder_out (paddle.Tensor): [B, Tmax, D]
encoder_mask (paddle.Tensor): [B, 1, Tmax]
ys_pad (paddle.Tensor): [B, Umax]
ys_pad_lens (paddle.Tensor): [B]
Returns:
Tuple[paddle.Tensor, float]: attention_loss, accuracy rate
"""
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
ys_in_lens)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
def _forward_encoder(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encoder pass.
Args:
speech (paddle.Tensor): [B, Tmax, D]
speech_lengths (paddle.Tensor): [B]
decoding_chunk_size (int, optional): chuck size. Defaults to -1.
num_decoding_left_chunks (int, optional): nums chunks. Defaults to -1.
simulate_streaming (bool, optional): streaming or not. Defaults to False.
Returns:
Tuple[paddle.Tensor, paddle.Tensor]:
encoder hiddens (B, Tmax, D),
encoder hiddens mask (B, 1, Tmax).
"""
# Let's assume B = batch_size
# 1. Encoder
if simulate_streaming and decoding_chunk_size > 0:
encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
speech,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
else:
encoder_out, encoder_mask = self.encoder(
speech,
speech_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
return encoder_out, encoder_mask
def translate(
self,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
beam_size: int=10,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False, ) -> paddle.Tensor:
""" Apply beam search on attention decoder
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
paddle.Tensor: decoding result, (batch, max_result_len)
"""
assert speech.shape[0] == speech_lengths.shape[0]
assert decoding_chunk_size != 0
device = speech.place
batch_size = speech.shape[0]
# Let's assume B = batch_size and N = beam_size
# 1. Encoder
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks,
simulate_streaming) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
encoder_dim = encoder_out.size(2)
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
hyps = paddle.ones(
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
# log scale score
scores = paddle.to_tensor(
[0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
device) # (B*N, 1)
end_flag = paddle.zeros_like(scores, dtype=paddle.bool) # (B*N, 1)
cache: Optional[List[paddle.Tensor]] = None
# 2. Decoder forward step by step
for i in range(1, maxlen + 1):
# Stop if all batch and all beam produce eos
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.st_decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
# 2.3 Seconde beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
scores = scores.view(-1, 1) # (B*N, 1)
# 2.4. Compute base index in top_k_index,
# regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
1, beam_size) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
# 2.5 Update best hyps
best_k_pred = paddle.index_select(
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = paddle.index_select(
hyps, index=best_hyps_index, axis=0) # (B*N, i)
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
dim=1) # (B*N, i+1)
# 2.6 Update end flag
end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
# 3. Select best of best
scores = scores.view(batch_size, beam_size)
# TODO: length normalization
best_index = paddle.argmax(scores, axis=-1).long() # (B)
best_hyps_index = best_index + paddle.arange(
batch_size, dtype=paddle.long) * beam_size
best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
best_hyps = best_hyps[:, 1:]
return best_hyps
@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@jit.export
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
subsampling_cache: Optional[paddle.Tensor]=None,
elayers_output_cache: Optional[List[paddle.Tensor]]=None,
conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
paddle.Tensor]]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
Args:
xs (paddle.Tensor): chunk input
subsampling_cache (Optional[paddle.Tensor]): subsampling cache
elayers_output_cache (Optional[List[paddle.Tensor]]):
transformer/conformer encoder layers output cache
conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
cnn cache
Returns:
paddle.Tensor: output, it ranges from time 0 to current chunk.
paddle.Tensor: subsampling cache
List[paddle.Tensor]: attention cache
List[paddle.Tensor]: conformer cnn cache
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output
Returns:
paddle.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
@jit.export
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
hyps_lens: paddle.Tensor,
encoder_out: paddle.Tensor, ) -> paddle.Tensor:
""" Export interface for c++ call, forward decoder with multiple
hypothesis from ctc prefix beam search and one encoder output
Args:
hyps (paddle.Tensor): hyps from ctc prefix beam search, already
pad sos at the begining, (B, T)
hyps_lens (paddle.Tensor): length of each hyp in hyps, (B)
encoder_out (paddle.Tensor): corresponding encoder output, (B=1, T, D)
Returns:
paddle.Tensor: decoder output, (B, L)
"""
assert encoder_out.size(0) == 1
num_hyps = hyps.size(0)
assert hyps_lens.size(0) == num_hyps
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
# (B, 1, T)
encoder_mask = paddle.ones(
[num_hyps, 1, encoder_out.size(1)], dtype=paddle.bool)
# (num_hyps, max_hyps_len, vocab_size)
decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
hyps_lens)
decoder_out = paddle.nn.functional.log_softmax(decoder_out, dim=-1)
return decoder_out
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
lang_model_path: str,
beam_alpha: float,
beam_beta: float,
beam_size: int,
cutoff_prob: float,
cutoff_top_n: int,
num_processes: int,
ctc_weight: float=0.0,
decoding_chunk_size: int=-1,
num_decoding_left_chunks: int=-1,
simulate_streaming: bool=False):
"""u2 decoding.
Args:
feats (Tenosr): audio features, (B, T, D)
feats_lengths (Tenosr): (B)
text_feature (TextFeaturizer): text feature object.
decoding_method (str): decoding mode, e.g.
'fullsentence',
'simultaneous'
lang_model_path (str): lm path.
beam_alpha (float): lm weight.
beam_beta (float): length penalty.
beam_size (int): beam size for search
cutoff_prob (float): for prune.
cutoff_top_n (int): for prune.
num_processes (int):
ctc_weight (float, optional): ctc weight for attention rescoring decode mode. Defaults to 0.0.
decoding_chunk_size (int, optional): decoding chunk size. Defaults to -1.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here.
num_decoding_left_chunks (int, optional):
number of left chunks for decoding. Defaults to -1.
simulate_streaming (bool, optional): simulate streaming inference. Defaults to False.
Raises:
ValueError: when not support decoding_method.
Returns:
List[List[int]]: transcripts.
"""
batch_size = feats.size(0)
if decoding_method == 'fullsentence':
hyps = self.translate(
feats,
feats_lengths,
beam_size=beam_size,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)
hyps = [hyp.tolist() for hyp in hyps]
else:
raise ValueError(f"Not support decoding method: {decoding_method}")
res = [text_feature.defeaturize(hyp) for hyp in hyps]
return res
class U2STModel(U2STBaseModel):
def __init__(self, configs: dict):
vocab_size, encoder, decoder = U2STModel._init_from_config(configs)
if isinstance(decoder, Tuple):
st_decoder, asr_decoder, ctc = decoder
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
st_decoder=st_decoder,
decoder=asr_decoder,
ctc=ctc,
**configs['model_conf'])
else:
super().__init__(
vocab_size=vocab_size,
encoder=encoder,
st_decoder=decoder,
**configs['model_conf'])
@classmethod
def _init_from_config(cls, configs: dict):
"""init sub module for model.
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
int, nn.Layer, nn.Layer, nn.Layer: vocab size, encoder, decoder, ctc
"""
if configs['cmvn_file'] is not None:
mean, istd = load_cmvn(configs['cmvn_file'],
configs['cmvn_file_type'])
global_cmvn = GlobalCMVN(
paddle.to_tensor(mean, dtype=paddle.float),
paddle.to_tensor(istd, dtype=paddle.float))
else:
global_cmvn = None
input_dim = configs['input_dim']
vocab_size = configs['output_dim']
assert input_dim != 0, input_dim
assert vocab_size != 0, vocab_size
encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer':
encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
elif encoder_type == 'conformer':
encoder = ConformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
else:
raise ValueError(f"not support encoder type:{encoder_type}")
st_decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
asr_weight = configs['model_conf']['asr_weight']
logger.info(f"ASR Joint Training Weight: {asr_weight}")
if asr_weight > 0.:
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
return vocab_size, encoder, (st_decoder, decoder, ctc)
else:
return vocab_size, encoder, st_decoder
@classmethod
def from_config(cls, configs: dict):
"""init model.
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
nn.Layer: U2STModel
"""
model = cls(configs)
return model
@classmethod
def from_pretrained(cls, dataloader, config, checkpoint_path):
"""Build a DeepSpeech2Model model from a pretrained model.
Args:
dataloader (paddle.io.DataLoader): not used.
config (yacs.config.CfgNode): model configs
checkpoint_path (Path or str): the path of pretrained model checkpoint, without extension name
Returns:
DeepSpeech2Model: The model built from pretrained result.
"""
config.defrost()
config.input_dim = dataloader.collate_fn.feature_size
config.output_dim = dataloader.collate_fn.vocab_size
config.freeze()
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model
class U2STInferModel(U2STModel):
def __init__(self, configs: dict):
super().__init__(configs)
def forward(self,
feats,
feats_lengths,
decoding_chunk_size=-1,
num_decoding_left_chunks=-1,
simulate_streaming=False):
"""export model function
Args:
feats (Tensor): [B, T, D]
feats_lengths (Tensor): [B]
Returns:
List[List[int]]: best path result
"""
return self.translate(
feats,
feats_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks,
simulate_streaming=simulate_streaming)

@ -0,0 +1,53 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu']
def bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in word-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference length is zero.
"""
return sacrebleu.corpus_bleu(hypothesis, reference)
def char_bleu(hypothesis, reference):
"""Calculate BLEU. BLEU compares reference text and
hypothesis text in char-level using scarebleu.
:param reference: The reference sentences.
:type reference: list[list[str]]
:param hypothesis: The hypothesis sentence.
:type hypothesis: list[str]
:raises ValueError: If the reference number is zero.
"""
hypothesis =[' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref ]for ref in reference ]
return sacrebleu.corpus_bleu(hypothesis, reference)
Loading…
Cancel
Save