[TTS] Fix attention bugs and sort VITS data by feats_lengths (#2770)

pull/2771/head
HuangLiangJie 3 years ago committed by GitHub
parent 6725bcd823
commit 2e51e0da90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -187,7 +187,7 @@ def main():
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata.sort(key=itemgetter('feats_lengths'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:

@ -166,7 +166,7 @@ def process_sentences(config,
if record:
results.append(record)
results.sort(key=itemgetter("utt_id"))
results.sort(key=itemgetter("feats_lengths"))
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
for item in results:
writer.write(item)

@ -24,13 +24,13 @@ import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.datasets.sampler import ErnieSATSampler
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.models.vits import VITSEvaluator
from paddlespeech.t2s.models.vits import VITSUpdater
@ -107,12 +107,12 @@ def train_sp(args, config):
converters=converters, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_sampler = ErnieSATSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
dev_sampler = DistributedBatchSampler(
dev_sampler = ErnieSATSampler(
dev_dataset,
batch_size=config.batch_size,
shuffle=False,

@ -196,7 +196,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x
@ -299,7 +299,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
if self.zero_triu:
ones = paddle.ones((t1, t2))
x = x * paddle.tril(ones, t2 - 1)[None, None, :, :]
x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
return x

Loading…
Cancel
Save