Merge pull request #970 from zh794390558/align

[bugfix] ctc alignment
pull/972/head
Jackwaterveg 3 years ago committed by GitHub
commit aadaeb0fe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -13,7 +13,7 @@ bc flac jq vim tig tree pkg-config libsndfile1 libflac-dev libvorbis-dev libboos
```
build-essential cmake libboost-system-dev libboost-thread-dev libboost-program-options-dev libboost-test-dev libeigen3-dev zlib1g-dev libbz2-dev liblzma-dev gcc-5 g++-5
```
### The dependencies of sox:
```
@@ -25,7 +25,7 @@ libvorbis-dev libmp3lame-dev libmad-ocaml-dev
```
kenlm
sox
sox
mfa
openblas
kaldi

@@ -554,10 +554,11 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@@ -527,10 +527,11 @@ class U2Tester(U2Trainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@@ -543,10 +543,10 @@ class U2STTester(U2STTrainer):
@paddle.no_grad()
def align(self):
ctc_utils.ctc_align(
self.model, self.align_loader, self.config.decoding.batch_size,
self.align_loader.collate_fn.stride_ms,
self.align_loader.collate_fn.vocab_list, self.args.result_file)
ctc_utils.ctc_align(self.config, self.model, self.align_loader,
self.config.decoding.batch_size,
self.config.collator.stride_ms, self.vocab_list,
self.args.result_file)
def load_inferspec(self):
"""infer model and input spec.

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from pathlib import Path
from typing import List
import numpy as np
@@ -139,26 +140,27 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
return output_alignment
def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
def ctc_align(config, model, dataloader, batch_size, stride_ms, token_dict,
result_file):
"""ctc alignment.
Args:
config (cfgNode): config
model (nn.Layer): U2 Model.
dataloader (io.DataLoader): dataloader.
batch_size (int): decoding batchsize.
stride_ms (int): audio feature stride in ms unit.
token_dict (List[str]): vocab list, e.g. ['blank', 'unk', 'a', 'b', '<eos>'].
result_file (str): alignment output file, e.g. xxx.align.
result_file (str): alignment output file, e.g. /path/to/xxx.align.
"""
if batch_size > 1:
logger.fatal('alignment mode must be running with batch_size == 1')
sys.exit(1)
assert result_file and result_file.endswith('.align')
model.eval()
# conv subsampling rate
subsample = utility.get_subsample(config)
logger.info(f"Align Total Examples: {len(dataloader.dataset)}")
with open(result_file, 'w') as fout:
@@ -187,13 +189,11 @@ def ctc_align(model, dataloader, batch_size, stride_ms, token_dict,
logger.info(f"align tokens: {key[0]}, {align_segs}")
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(align_segs, subsample,
token_dict)
# write tier
align_output_path = Path(self.args.result_file).parent / "align"
align_output_path = Path(result_file).parent / "align"
align_output_path.mkdir(parents=True, exist_ok=True)
tier_path = align_output_path / (key[0] + ".tier")
with tier_path.open('w') as f:

@@ -67,19 +67,16 @@ class LJSpeechCollector(object):
# Sort by text_len in descending order
texts = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(texts, text_lens), key=lambda x: x[1], reverse=True)
]
mels = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(mels, text_lens), key=lambda x: x[1], reverse=True)
]
mel_lens = [
i
for i, _ in sorted(
i for i, _ in sorted(
zip(mel_lens, text_lens), key=lambda x: x[1], reverse=True)
]

Loading…
Cancel
Save