@@ -0,0 +1,85 @@
import numpy as np
import paddle
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-05):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-05):
    return paddle.log(paddle.clip(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return paddle.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    y = paddle.to_tensor(y.detach().cpu().numpy())
    if paddle.min(y) < -1.0:
        print("min value is ", paddle.min(y))
    if paddle.max(y) > 1.0:
        print("max value is ", paddle.max(y))
    global mel_basis, hann_window
    if f"{str(fmax)}_{str(y.place)}" not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[str(fmax) + "_" + str(y.place)] = paddle.to_tensor(
            mel, dtype="float32", place=y.place
        )
        hann_window[str(y.place)] = paddle.audio.functional.get_window(
            win_length=win_size, dtype="float32", window="hann"
        ).to(y.place)

    y = paddle.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # STFT with the cached Hann window and the caller-provided parameters;
    # center=False pairs with the manual reflect padding above.
    stft = paddle.signal.stft(
        y,
        n_fft=n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[str(y.place)],
        center=center,
    )

    spec = paddle.as_real(stft)
    spec = paddle.sqrt(spec.pow(2).sum(-1) + 1e-09)
    spec = paddle.matmul(mel_basis[str(fmax) + "_" + str(y.place)], spec)
    spec = spectral_normalize_torch(spec)
    return spec
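
# Example usage (a minimal sketch; the shapes and the 24 kHz / 80-mel settings
# below are assumptions, not values fixed by this file):
#
#   wav = paddle.randn([1, 24000]).clip(-1.0, 1.0)   # one second of fake audio
#   mel = mel_spectrogram(wav, n_fft=1920, num_mels=80, sampling_rate=24000,
#                         hop_size=480, win_size=1920, fmin=0, fmax=8000)
#   # mel: [1, 80, num_frames], log-compressed magnitudes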
@@ -0,0 +1,130 @@
import sys
from pathlib import Path

import paddle
import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM

from paddlespeech.t2s.models.CosyVoice.cosyvoice import CosyVoice2
from paddlespeech.t2s.frontend.CosyVoiceFrontEnd.frontend import CosyVoiceFrontEnd
from paddlespeech.t2s.frontend.CosyVoiceFrontEnd.tokenizer import get_qwen_tokenizer
from paddlespeech.t2s.models.CosyVoice.llm import Qwen2LM, Qwen2Encoder
from paddlespeech.t2s.models.CosyVoice.common import ras_sampling

paddle.seed(42)

hyper_yaml_path = "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/cosyvoice2.yaml"
with open(hyper_yaml_path, 'r') as f:
    configs = load_hyperpyyaml(f)

# frontend = CosyVoiceFrontEnd(
#     lambda: get_qwen_tokenizer('/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN', skip_special_tokens=True),
#     configs['feat_extractor'],
#     "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/campplus.onnx",
#     "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/speech_tokenizer_v2.onnx",
#     "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/spk2info.pt",
#     configs['allowed_special']
# )
# prompt_wav = "../CosyVoice/gaoziyuan_10.wav"
# tts_text = frontend.text_normalize("定位解决了云存储服务和 data loader 等多个环节的性能波动问题", split=True, text_frontend=True)
# prompt_text = frontend.text_normalize('清晨的阳光透过树叶洒在地面上,微风轻轻吹过,带来花草的香气。街边的咖啡店刚开门,传来阵阵烘焙的香味,让人感到放松与愉快。', split=True, text_frontend=True)
# model_input = frontend.frontend_zero_shot(tts_text, prompt_text, prompt_wav, 24000, '')
# paddle.save(model_input, 'model_input.pdparams')

# # cosyvoice_model = CosyVoice2("../CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle")
# model = AutoModelForCausalLM.from_pretrained('/root/paddlejob/workspace/zhangjinghong/test/pretrained/Qwen/Qwen2-0.5B')
# print(type(model))
# llm = Qwen2Encoder(model)
# qwen_lm = Qwen2LM(896, 896, 6561, llm, ras_sampling)
# state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/llm.pdparams")
# qwen_lm.set_state_dict(state_dict)


# Reference inputs dumped from the torch implementation:
# new_dict = torch.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/data.pt")
# text = new_dict['text']
# text_len = new_dict['text_len']
# prompt_text = new_dict['prompt_text']
# prompt_text_len = new_dict['prompt_text_len']
# prompt_speech_token = new_dict['prompt_speech_token']
# prompt_speech_token_len = new_dict['prompt_speech_token_len']
# embedding = new_dict['embedding']
# uuid = new_dict['uuid']


# Or the inputs produced by the Paddle frontend:
# text = model_input['text']
# text_len = model_input['text_len']
# prompt_text = model_input['prompt_text']
# prompt_text_len = model_input['prompt_text_len']
# prompt_speech_token = model_input['llm_prompt_speech_token']
# prompt_speech_token_len = model_input['llm_prompt_speech_token_len']
# embedding = model_input['llm_embedding']
# uuid = new_dict['uuid']


# # Move everything to the same device and convert to Paddle tensors.
# device = paddle.CUDAPlace(0)  # use the GPU
# text_tensor = paddle.to_tensor(text).cuda()
# prompt_text_tensor = paddle.to_tensor(prompt_text).cuda()
# prompt_speech_token_tensor = paddle.to_tensor(prompt_speech_token).cuda()
# embedding_tensor = paddle.to_tensor(embedding, dtype='float32').cuda()
# # Make sure the length tensors are converted and placed on the same device too.
# text_len_tensor = text_len.cuda() if hasattr(text_len, 'cuda') else paddle.to_tensor(text_len).cuda()
# prompt_text_len_tensor = prompt_text_len.cuda() if hasattr(prompt_text_len, 'cuda') else paddle.to_tensor(prompt_text_len).cuda()
# prompt_speech_token_len_tensor = prompt_speech_token_len.cuda() if hasattr(prompt_speech_token_len, 'cuda') else paddle.to_tensor(prompt_speech_token_len).cuda()
# token = []
# for i in qwen_lm.inference(text=text_tensor,
#                            text_len=text_len_tensor,
#                            prompt_text=prompt_text_tensor,
#                            prompt_text_len=prompt_text_len_tensor,
#                            prompt_speech_token=prompt_speech_token_tensor,
#                            prompt_speech_token_len=prompt_speech_token_len_tensor,
#                            embedding=embedding_tensor,
#                            uuid=uuid):
#     token.append(i)
#     # print(text)
#     print("token: ", i)

############################################################################################################################
# Flow: generate the mel spectrogram from speech tokens.

flow = configs['flow']
flow_state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/flow.pdparams")
flow.set_state_dict(flow_state_dict)
input_dict = torch.load("/root/paddlejob/workspace/zhangjinghong/test/CosyVoice/data.pt")
flow.eval()
tts_mel, _ = flow.inference(
    token=paddle.to_tensor(input_dict['token']),
    token_len=paddle.to_tensor(input_dict['token_len']),
    prompt_token=paddle.to_tensor(input_dict['prompt_token'].cpu().numpy(), dtype='int32'),
    prompt_token_len=paddle.to_tensor(input_dict['prompt_token_len'].cpu().numpy()),
    prompt_feat=paddle.to_tensor(input_dict['prompt_feat'].cpu().numpy()),
    prompt_feat_len=paddle.to_tensor(input_dict['prompt_feat_len'].cpu().numpy()),
    embedding=paddle.to_tensor(input_dict['embedding'].cpu().numpy()),
    streaming=input_dict['streaming'],
    finalize=input_dict['finalize'],
)
paddle.save(tts_mel, "tts_mel.pdparams")

############################################################################################################################
# HiFT vocoder: generate the waveform from the mel spectrogram.

from paddlespeech.t2s.models.hifigan.cosy_hifigan import HiFTGenerator
from paddlespeech.t2s.models.hifigan.f0_predictor import ConvRNNF0Predictor

hift_state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/hift.pdparams")
input_mel = paddle.to_tensor(torch.load("../CosyVoice/tts_mel.pt").detach().cpu().numpy()).cuda()
hift_cache_source = paddle.to_tensor(torch.load("../CosyVoice/hift_cache_source.pt").detach().cpu().numpy()).cuda()
# hift_cache_source = paddle.zeros([1, 1, 0])
hift_configs = configs['hift']
f0_config = configs['f0_predictor']
f0_predictor = ConvRNNF0Predictor(**f0_config)

hift_configs['f0_predictor'] = f0_predictor
hift = HiFTGenerator(**hift_configs)
hift.set_state_dict(hift_state_dict)
# for k, v in hift.state_dict().items():
#     print(k, v.shape)
# print("---" * 40)
for k, v in hift_state_dict.items():
    print(k, v.shape)
tts_speech, tts_source = hift.inference(speech_feat=input_mel, cache_source=hift_cache_source)
paddle.save(tts_speech, "speech.pdparams")
# tts_speech, _ = hift.inference(input_dict['tts_mel'], input_dict['cache_source'])

torchaudio.save("paddle.wav", torch.tensor(tts_speech.numpy()), 24000)
# sf.write("paddle.wav", tts_speech[0], 24000)
@@ -0,0 +1,110 @@
import json
import logging
import os

import paddle
import torchaudio
import paddlespeech

logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s")


def read_lists(list_file):
    lists = []
    with open(list_file, "r", encoding="utf8") as fin:
        for line in fin:
            lists.append(line.strip())
    return lists


def read_json_lists(list_file):
    lists = read_lists(list_file)
    results = {}
    for fn in lists:
        with open(fn, "r", encoding="utf8") as fin:
            results.update(json.load(fin))
    return results


def load_wav(wav, target_sr, min_sr=16000):
    speech, sample_rate = torchaudio.load(wav, backend="soundfile")
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        assert (
            sample_rate >= min_sr
        ), "wav sample rate {} must be at least {}".format(sample_rate, min_sr)
        speech = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_sr
        )(speech)
    return speech
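
# Example (a sketch; "prompt.wav" is a hypothetical file): load a clip as mono
# and resample it to the model's rate.
#
#   speech = load_wav("prompt.wav", target_sr=24000)  # torch tensor, shape (1, n)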

def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
    import tensorrt as trt

    logging.info("Converting onnx to trt...")
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32)
    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    profile = builder.create_optimization_profile()
    with open(onnx_model, "rb") as f:
        if not parser.parse(f.read()):
            for error in range(parser.num_errors):
                print(parser.get_error(error))
            raise ValueError("failed to parse {}".format(onnx_model))
    for i in range(len(trt_kwargs["input_names"])):
        profile.set_shape(
            trt_kwargs["input_names"][i],
            trt_kwargs["min_shape"][i],
            trt_kwargs["opt_shape"][i],
            trt_kwargs["max_shape"][i],
        )
    tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
    for i in range(network.num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.dtype = tensor_dtype
    for i in range(network.num_outputs):
        output_tensor = network.get_output(i)
        output_tensor.dtype = tensor_dtype
    config.add_optimization_profile(profile)
    engine_bytes = builder.build_serialized_network(network, config)
    with open(trt_model, "wb") as f:
        f.write(engine_bytes)
    logging.info("Successfully converted onnx to trt...")
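
# A minimal sketch of the trt_kwargs layout this function expects; the tensor
# name and shapes below are hypothetical, not taken from an actual export:
#
#   trt_kwargs = {
#       "input_names": ["speech_feat"],
#       "min_shape": [(1, 80, 4)],
#       "opt_shape": [(1, 80, 200)],
#       "max_shape": [(1, 80, 3000)],
#   }
#   convert_onnx_to_trt("flow.trt", trt_kwargs, "flow.onnx", fp16=True)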

def export_cosyvoice2_vllm(model, model_path, device):
    if os.path.exists(model_path):
        return
    dtype = paddle.bfloat16
    use_bias = model.llm_decoder.bias is not None
    model.llm.model.lm_head = model.llm_decoder
    embed_tokens = model.llm.model.model.embed_tokens
    model.llm.model.set_input_embeddings(model.speech_embedding)
    model.llm.model.to(device)
    model.llm.model.to(dtype)
    tmp_vocab_size = model.llm.model.config.vocab_size
    tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
    del model.llm.model.generation_config.eos_token_id
    del model.llm.model.config.bos_token_id
    del model.llm.model.config.eos_token_id
    model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
    model.llm.model.config.tie_word_embeddings = False
    model.llm.model.config.use_bias = use_bias
    model.llm.model.save_pretrained(model_path)
    if use_bias:
        os.system(
            "sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json".format(
                os.path.abspath(model_path)
            )
        )
    # Restore the original config and embeddings after the export.
    model.llm.model.config.vocab_size = tmp_vocab_size
    model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
    model.llm.model.set_input_embeddings(embed_tokens)
@@ -0,0 +1,747 @@
import math
from typing import Tuple

import paddle

import paddlespeech

__all__ = [
    "get_mel_banks",
    "inverse_mel_scale",
    "inverse_mel_scale_scalar",
    "mel_scale",
    "mel_scale_scalar",
    "spectrogram",
    "fbank",
    "mfcc",
    "vtln_warp_freq",
    "vtln_warp_mel_freq",
]
EPSILON = paddle.to_tensor(paddle.finfo(paddle.float32).eps)
MILLISECONDS_TO_SECONDS = 0.001
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]


def _get_epsilon(device, dtype):
    return EPSILON.to(device=device, dtype=dtype)


def _next_power_of_2(x: int) -> int:
    """Returns the smallest power of 2 that is greater than x"""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()

def _get_strided(
    waveform: paddle.Tensor, window_size: int, window_shift: int, snip_edges: bool
) -> paddle.Tensor:
    """Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.shape[0]
    strides = window_shift * waveform.strides[0], waveform.strides[0]
    if snip_edges:
        if num_samples < window_size:
            return paddle.empty([0, 0], dtype=waveform.dtype)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = paddle.flip(x=waveform, axis=[0])
        m = (num_samples + window_shift // 2) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            pad_left = reversed_waveform[-pad:]
            waveform = paddle.concat((pad_left, waveform, pad_right), axis=0)
        else:
            waveform = paddle.concat((waveform[-pad:], pad_right), axis=0)
    sizes = m, window_size
    return waveform.as_strided(shape=sizes, stride=strides)
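
# Frame-count sketch (assumed numbers): for num_samples=16000, window_size=400
# and window_shift=160 with snip_edges=True, m = 1 + (16000 - 400) // 160 = 98,
# so the result is a (98, 400) strided view of the waveform.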

def _feature_window_function(
    window_type: str,
    window_size: int,
    blackman_coeff: float,
    device,
    dtype,
) -> paddle.Tensor:
    """Returns a window function with the given type and size"""
    if window_type == HANNING:
        return paddle.audio.functional.get_window(
            win_length=window_size, fftbins=False, dtype=dtype, window="hann"
        )
    elif window_type == HAMMING:
        # Kaldi's Hamming window uses alpha=0.54, beta=0.46, which matches the
        # standard "hamming" window provided by paddle.audio.
        return paddle.audio.functional.get_window(
            win_length=window_size, fftbins=False, dtype=dtype, window="hamming"
        )
    elif window_type == POVEY:
        return paddle.audio.functional.get_window(
            win_length=window_size, fftbins=False, dtype=dtype, window="hann"
        ).pow(0.85)
    elif window_type == RECTANGULAR:
        return paddle.ones([window_size], dtype=dtype)
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = paddle.arange(window_size, dtype=dtype)
        return (
            blackman_coeff
            - 0.5 * paddle.cos(a * window_function)
            + (0.5 - blackman_coeff) * paddle.cos(2 * a * window_function)
        ).astype(dtype)
    else:
        raise ValueError("Invalid window type " + window_type)

def _get_log_energy(
    strided_input: paddle.Tensor, epsilon: paddle.Tensor, energy_floor: float
) -> paddle.Tensor:
    """Returns the log energy of size (m) for a strided_input (m,*)"""
    place, dtype = strided_input.place, strided_input.dtype
    log_energy = paddle.maximum(strided_input.pow(2).sum(axis=1), epsilon).log()
    if energy_floor == 0.0:
        return log_energy
    return paddle.maximum(
        log_energy, paddle.to_tensor(math.log(energy_floor), place=place, dtype=dtype)
    )


def _get_waveform_and_window_properties(
    waveform: paddle.Tensor,
    channel: int,
    sample_frequency: float,
    frame_shift: float,
    frame_length: float,
    round_to_power_of_two: bool,
    preemphasis_coefficient: float,
) -> Tuple[paddle.Tensor, int, int, int]:
    """Gets the waveform and window properties"""
    channel = max(channel, 0)
    assert channel < waveform.shape[0], "Invalid channel {} for size {}".format(
        channel, waveform.shape[0]
    )
    waveform = waveform[channel, :]
    window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
    window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
    padded_window_size = (
        _next_power_of_2(window_size) if round_to_power_of_two else window_size
    )
    assert (
        2 <= window_size <= len(waveform)
    ), "choose a window size {} that is in [2, {}]".format(window_size, len(waveform))
    assert 0 < window_shift, "`window_shift` must be greater than 0"
    assert (
        padded_window_size % 2 == 0
    ), "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
    assert (
        0.0 <= preemphasis_coefficient <= 1.0
    ), "`preemphasis_coefficient` must be between [0,1]"
    assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
    return waveform, window_shift, window_size, padded_window_size

def _get_window(
    waveform: paddle.Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Gets a window and its log energy

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
    """

    place, dtype = waveform.place, waveform.dtype
    epsilon = _get_epsilon(place, dtype)
    strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
    if dither != 0.0:
        rand_gauss = paddle.randn(strided_input.shape, dtype=dtype)
        strided_input = strided_input + rand_gauss * dither
    if remove_dc_offset:
        row_means = paddle.mean(strided_input, axis=1).unsqueeze(1)
        strided_input = strided_input - row_means
    if raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)
    if preemphasis_coefficient != 0.0:
        offset_strided_input = paddle.nn.functional.pad(
            strided_input.unsqueeze(0), (1, 0), mode="replicate"
        ).squeeze(0)
        strided_input = (
            strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
        )
    window_function = _feature_window_function(
        window_type, window_size, blackman_coeff, place, dtype
    ).unsqueeze(0)
    strided_input = strided_input * window_function
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        strided_input = paddle.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)
    if not raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)
    return strided_input, signal_log_energy

def _subtract_column_mean(tensor: paddle.Tensor, subtract_mean: bool) -> paddle.Tensor:
    if subtract_mean:
        col_means = paddle.mean(tensor, axis=0).unsqueeze(0)
        tensor = tensor - col_means
    return tensor

def spectrogram(
    waveform: paddle.Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> paddle.Tensor:
    """Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
    """
    place, dtype = waveform.place, waveform.dtype
    epsilon = _get_epsilon(place, dtype)
    (
        waveform,
        window_shift,
        window_size,
        padded_window_size,
    ) = _get_waveform_and_window_properties(
        waveform,
        channel,
        sample_frequency,
        frame_shift,
        frame_length,
        round_to_power_of_two,
        preemphasis_coefficient,
    )
    if len(waveform) < min_duration * sample_frequency:
        return paddle.empty([0], dtype=dtype)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )
    fft = paddle.fft.rfft(strided_input)
    power_spectrum = paddle.maximum(fft.abs().pow(2.0), epsilon).log()
    power_spectrum[:, 0] = signal_log_energy
    power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
    return power_spectrum
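
# Example (a sketch with an assumed input): Kaldi-style log power spectrogram
# of one second of 16 kHz mono audio.
#
#   wav = paddle.randn([1, 16000])
#   spec = spectrogram(wav)  # defaults: 25 ms window, 10 ms shift, povey window
#   # spec: [98, 257] -- 400-sample windows padded to 512 -> 512 // 2 + 1 bins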

def inverse_mel_scale_scalar(mel_freq: float) -> float:
    return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)


def inverse_mel_scale(mel_freq: paddle.Tensor) -> paddle.Tensor:
    return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)


def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)


def mel_scale(freq: paddle.Tensor) -> paddle.Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()


def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: paddle.Tensor,
) -> paddle.Tensor:
    """This computes a VTLN warping function that is not the same as HTK's one,
    but has similar inputs (this function has the advantage of never producing
    empty bins).

    This function computes a warp function F(freq), defined between low_freq
    and high_freq inclusive, with the following properties:
        F(low_freq) == low_freq
        F(high_freq) == high_freq
    The function is continuous and piecewise linear with two inflection
    points.
    The lower inflection point (measured in terms of the unwarped
    frequency) is at frequency l, determined as described below.
    The higher inflection point is at a frequency h, determined as
    described below.
    If l <= f <= h, then F(f) = f/vtln_warp_factor.
    If the higher inflection point (measured in terms of the unwarped
    frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
    Since (by the last point) F(h) == h/vtln_warp_factor, then
    max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
    h = vtln_high_cutoff / max(1, 1/vtln_warp_factor)
      = vtln_high_cutoff * min(1, vtln_warp_factor).
    If the lower inflection point (measured in terms of the unwarped
    frequency) is at l, then min(l, F(l)) == vtln_low_cutoff.
    This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
      = vtln_low_cutoff * max(1, vtln_warp_factor).

    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): Given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert (
        vtln_low_cutoff > low_freq
    ), "be sure to set the vtln_low option higher than low_freq"
    assert (
        vtln_high_cutoff < high_freq
    ), "be sure to set the vtln_high option lower than high_freq [or negative]"
    l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    Fl = scale * l
    Fh = scale * h
    assert l > low_freq and h < high_freq
    scale_left = (Fl - low_freq) / (l - low_freq)
    scale_right = (high_freq - Fh) / (high_freq - h)
    res = paddle.empty_like(freq)
    outside_low_high_freq = (freq < low_freq) | (freq > high_freq)
    before_l = freq < l
    before_h = freq < h
    after_h = freq >= h
    # Assign from the widest region to the narrowest so that later writes
    # overwrite earlier ones in the overlapping frequency regions.
    res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
    res[before_h] = scale * freq[before_h]
    res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
    res[outside_low_high_freq] = freq[outside_low_high_freq]
    return res


def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: paddle.Tensor,
) -> paddle.Tensor:
    """
    Args:
        vtln_low_cutoff (float): Lower frequency cutoff for VTLN
        vtln_high_cutoff (float): Upper frequency cutoff for VTLN
        low_freq (float): Lower frequency cutoff in mel computation
        high_freq (float): Upper frequency cutoff in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    return mel_scale(
        vtln_warp_freq(
            vtln_low_cutoff,
            vtln_high_cutoff,
            low_freq,
            high_freq,
            vtln_warp_factor,
            inverse_mel_scale(mel_freq),
        )
    )

def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """
    Returns:
        (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded // 2
    nyquist = 0.5 * sample_freq
    if high_freq <= 0.0:
        high_freq += nyquist
    assert (
        0.0 <= low_freq < nyquist
        and 0.0 < high_freq <= nyquist
        and low_freq < high_freq
    ), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(
        low_freq, high_freq, nyquist
    )
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
    if vtln_high < 0.0:
        vtln_high += nyquist
    assert vtln_warp_factor == 1.0 or (
        low_freq < vtln_low < high_freq
        and 0.0 < vtln_high < high_freq
        and vtln_low < vtln_high
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )
    bin = paddle.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta
    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(
            vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel
        )
        center_mel = vtln_warp_mel_freq(
            vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel
        )
        right_mel = vtln_warp_mel_freq(
            vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel
        )
    center_freqs = inverse_mel_scale(center_mel)
    mel = mel_scale(
        fft_bin_width * paddle.arange(num_fft_bins, dtype=paddle.float32)
    ).unsqueeze(0)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)
    if vtln_warp_factor == 1.0:
        bins = paddle.maximum(paddle.zeros([1]), paddle.minimum(up_slope, down_slope))
    else:
        bins = paddle.zeros_like(up_slope)
        up_idx = (mel > left_mel) & (mel <= center_mel)
        down_idx = (mel > center_mel) & (mel < right_mel)
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]
    return bins, center_freqs
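
# Example (a sketch with typical fbank settings): 23 triangular filters over a
# 512-point FFT at 16 kHz, no VTLN warping.
#
#   bins, center_freqs = get_mel_banks(23, 512, 16000.0, 20.0, 0.0,
#                                      100.0, -500.0, 1.0)
#   # bins: [23, 256], center_freqs: [23]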
def fbank(
    waveform: paddle.Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> paddle.Tensor:
    """Create a fbank from a raw audio signal. This matches the input/output of
    Kaldi's compute-fbank-feats."""
    place, dtype = waveform.place, waveform.dtype
    (
        waveform,
        window_shift,
        window_size,
        padded_window_size,
    ) = _get_waveform_and_window_properties(
        waveform,
        channel,
        sample_frequency,
        frame_shift,
        frame_length,
        round_to_power_of_two,
        preemphasis_coefficient,
    )
    if len(waveform) < min_duration * sample_frequency:
        return paddle.empty([0], dtype=dtype)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )
    spectrum = paddle.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)
    mel_energies, _ = get_mel_banks(
        num_mel_bins,
        padded_window_size,
        sample_frequency,
        low_freq,
        high_freq,
        vtln_low,
        vtln_high,
        vtln_warp,
    )
    mel_energies = mel_energies.astype(dtype)
    # Pad a zero column so the banks cover all padded_window_size // 2 + 1 bins.
    mel_energies = paddle.nn.functional.pad(
        mel_energies, (0, 1), mode="constant", value=0
    )
    mel_energies = paddle.mm(input=spectrum, mat2=mel_energies.T)
    if use_log_fbank:
        mel_energies = paddle.maximum(
            mel_energies, _get_epsilon(place, dtype)
        ).log()
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)
        if htk_compat:
            mel_energies = paddle.concat((mel_energies, signal_log_energy), axis=1)
        else:
            mel_energies = paddle.concat((signal_log_energy, mel_energies), axis=1)
    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies
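
# Example (a sketch; shapes assumed): 80-dim log-mel filterbank features, as
# commonly used for speaker-embedding and speech-tokenizer front ends.
#
#   wav = paddle.randn([1, 16000])
#   feats = fbank(wav, num_mel_bins=80, sample_frequency=16000.0, dither=0.0)
#   # feats: [num_frames, 80]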

def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> paddle.Tensor:
    dct_matrix = paddle.audio.functional.create_dct(
        n_mfcc=num_mel_bins, n_mels=num_mel_bins, norm="ortho"
    )
    # Kaldi expects the first cepstral coefficient to be a weighted sum with
    # factor sqrt(1 / num_mel_bins), so the first DCT column is replaced.
    first_col = paddle.full(
        shape=[num_mel_bins, 1], fill_value=math.sqrt(1 / float(num_mel_bins))
    )
    if num_ceps > 1:
        other_cols = dct_matrix[:, 1:num_ceps]
        dct_matrix = paddle.concat([first_col, other_cols], axis=1)
    else:
        dct_matrix = first_col
    return dct_matrix


def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> paddle.Tensor:
    i = paddle.arange(num_ceps)
    return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i / cepstral_lifter)

def mfcc(
    waveform: paddle.Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> paddle.Tensor:
    """Create an MFCC from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``"povey"``)

    Returns:
        Tensor: An MFCC identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert (
        num_ceps <= num_mel_bins
    ), "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (
        num_ceps,
        num_mel_bins,
    )
    dtype = waveform.dtype
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )
    if use_energy:
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : num_mel_bins + mel_offset]
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).astype(dtype)
    feature = feature.matmul(dct_matrix)
    if cepstral_lifter != 0.0:
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.astype(dtype)
    if use_energy:
        feature[:, 0] = signal_log_energy
    if htk_compat:
        energy = feature[:, 0].unsqueeze(1)
        feature = feature[:, 1:]
        if not use_energy:
            energy *= math.sqrt(2)
        feature = paddle.concat((feature, energy), axis=1)
    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
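
# Example (a sketch; shapes assumed): 13 MFCCs from the default 23 mel bins.
#
#   wav = paddle.randn([1, 16000])
#   coeffs = mfcc(wav, num_ceps=13, num_mel_bins=23)
#   # coeffs: [num_frames, 13]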
@@ -0,0 +1,578 @@
import base64
import os
from functools import lru_cache
from typing import Optional

import paddle

import tiktoken
from whisper.tokenizer import Tokenizer
from paddlenlp.transformers import AutoTokenizer

LANGUAGES = {
    "en": "english", "zh": "chinese", "de": "german", "es": "spanish",
    "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese",
    "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan",
    "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian",
    "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese",
    "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay",
    "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian",
    "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu",
    "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin",
    "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak",
    "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali",
    "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada",
    "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque",
    "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian",
    "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili",
    "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala",
    "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali",
    "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian",
    "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic",
    "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese",
    "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk",
    "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar",
    "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese",
    "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa",
    "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese",
    "minnan": "minnan", "wuyu": "wuyu", "dialect": "dialect",
    "zh/en": "zh/en", "en/zh": "en/zh",
}
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht",
    "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro",
    "moldovan": "ro", "sinhalese": "si", "castilian": "es", "mandarin": "zh",
}
AUDIO_EVENT = {
    "ASR": "ASR", "AED": "AED", "SER": "SER",
    "Speech": "Speech", "/Speech": "/Speech",
    "BGM": "BGM", "/BGM": "/BGM",
    "Laughter": "Laughter", "/Laughter": "/Laughter",
    "Applause": "Applause", "/Applause": "/Applause",
}
EMOTION = {"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL"}
TTS_Vocal_Token = {
    "TTS/B": "TTS/B", "TTS/O": "TTS/O", "TTS/Q": "TTS/Q", "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO", "TTS/CL": "TTS/CL", "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}

@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as fin:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in fin if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1
    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str="'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
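
# Example (a sketch; assumes the "gpt2.tiktoken" vocabulary file ships in the
# package's assets directory):
#
#   enc = get_encoding("gpt2", num_languages=99)
#   ids = enc.encode("hello world")
#   text = enc.decode(ids)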

@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")
    if multilingual:
        encoding_name = "multilingual_zh_ja_yue_char_del"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None
    encoding = get_encoding(name=encoding_name, num_languages=num_languages)
    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )

class CosyVoice2Tokenizer:
    def __init__(self, token_path, skip_special_tokens=True):
        super().__init__()
        special_tokens = {
            "eos_token": "<|endoftext|>",
            "pad_token": "<|endoftext|>",
            "additional_special_tokens": [
                "<|im_start|>",
                "<|im_end|>",
                "<|endofprompt|>",
                "[breath]",
                "<strong>",
                "</strong>",
                "[noise]",
                "[laughter]",
                "[cough]",
                "[clucking]",
                "[accent]",
                "[quick_breath]",
                "<laughter>",
                "</laughter>",
                "[hissing]",
                "[sigh]",
                "[vocalized-noise]",
                "[lipsmack]",
                "[mn]",
            ],
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        # Encode a batch of one so that input_ids is a list of sequences and
        # [0] selects the full token sequence for the input text.
        tokens = self.tokenizer([text])
        tokens = tokens["input_ids"][0]
        return tokens

    def decode(self, tokens):
        tokens = paddle.to_tensor(tokens, dtype=paddle.int64)
        text = self.tokenizer.batch_decode(
            [tokens], skip_special_tokens=self.skip_special_tokens
        )[0]
        return text

class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
    def __init__(self, token_path, skip_special_tokens=True):
        special_tokens = {
            "eos_token": "<|endoftext|>",
            "pad_token": "<|endoftext|>",
            "additional_special_tokens": [
                # chat/prompt delimiters and paralinguistic event tags
                "<|im_start|>", "<|im_end|>", "<|endofprompt|>", "[breath]",
                "<strong>", "</strong>", "[noise]", "[laughter]", "[cough]",
                "[clucking]", "[accent]", "[quick_breath]", "<laughter>",
                "</laughter>", "[hissing]", "[sigh]", "[vocalized-noise]",
                "[lipsmack]", "[mn]", "<|endofsystem|>",
                # English ARPAbet phones (with stress markers)
                "[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]",
                "[AH]", "[AH0]", "[AH1]", "[AH2]", "[AO]", "[AO0]", "[AO1]", "[AO2]",
                "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
                "[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]",
                "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]", "[EY0]", "[EY1]", "[EY2]",
                "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]",
                "[IY]", "[IY0]", "[IY1]", "[IY2]", "[JH]", "[K]", "[L]", "[M]",
                "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]",
                "[OY]", "[OY0]", "[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]",
                "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]",
                "[UW]", "[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
                # Mandarin pinyin initials/finals (with tone diacritics)
                "[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]",
                "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]", "[g]", "[h]",
                "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]",
                "[iào]", "[iá]", "[ián]", "[iáng]", "[iáo]", "[iè]", "[ié]",
                "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]",
                "[iāo]", "[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]",
                "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]", "[m]",
                "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]",
                "[sh]", "[t]", "[u]", "[uang]", "[ue]", "[un]", "[uo]", "[uà]",
                "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]",
                "[uè]", "[ué]", "[uì]", "[uí]", "[uò]", "[uó]", "[uā]", "[uāi]",
                "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
                "[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]",
                "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]", "[ào]", "[á]",
                "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]",
                "[èr]", "[é]", "[éi]", "[én]", "[éng]", "[ér]", "[ì]", "[ìn]",
                "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]",
                "[óng]", "[óu]", "[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]",
                "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
                "[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]",
                "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]", "[ǎn]", "[ǎng]",
                "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]",
                "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]",
            ],
        }
        self.special_tokens = special_tokens
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(token_path)
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens


@lru_cache(maxsize=None)
def get_qwen_tokenizer(
    token_path: str, skip_special_tokens: bool, version: str = "cosyvoice2"
):
    if version == "cosyvoice2":
        return CosyVoice2Tokenizer(
            token_path=token_path, skip_special_tokens=skip_special_tokens
        )
    elif version == "cosyvoice3":
        return CosyVoice3Tokenizer(
            token_path=token_path, skip_special_tokens=skip_special_tokens
        )
    else:
        raise ValueError(f"unsupported tokenizer version: {version}")
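
A quick usage sketch for the cached factory above. The model path is a placeholder for illustration, not a file introduced by this PR; `lru_cache` means repeated calls with the same arguments return the same tokenizer instance:

    # hedged example: `version` selects between the two tokenizer classes above
    tokenizer = get_qwen_tokenizer(
        token_path="pretrained_models/CosyVoice2-0.5B_paddle/CosyVoice-BlankEN",
        skip_special_tokens=True,
        version="cosyvoice2",
    )
    same = get_qwen_tokenizer(
        token_path="pretrained_models/CosyVoice2-0.5B_paddle/CosyVoice-BlankEN",
        skip_special_tokens=True,
        version="cosyvoice2",
    )
    assert tokenizer is same  # served from the lru_cache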
@ -0,0 +1,14 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cosyvoice import *
@ -0,0 +1,33 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddlespeech.t2s.modules.transformer.activation import Swish
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
from paddlespeech.t2s.modules.transformer.embedding import EspnetRelPositionalEncoding
from paddlespeech.t2s.modules.transformer.subsampling import LinearNoSubsampling


COSYVOICE_ACTIVATION_CLASSES = {
    "swish": Swish,
}
COSYVOICE_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
}
COSYVOICE_EMB_CLASSES = {
    "rel_pos_espnet": EspnetRelPositionalEncoding,
}
COSYVOICE_ATTENTION_CLASSES = {
    "rel_selfattn": RelPositionMultiHeadedAttention,
}
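
These registries let the hyperpyyaml config select encoder components by name. A minimal sketch of the intended lookup (the string keys shown are the only ones registered in this hunk):

    # hedged example: resolve config strings to classes and instantiate
    act_cls = COSYVOICE_ACTIVATION_CLASSES["swish"]        # -> Swish
    emb_cls = COSYVOICE_EMB_CLASSES["rel_pos_espnet"]      # -> EspnetRelPositionalEncoding
    activation = act_cls()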
@ -0,0 +1,222 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
import random
from typing import List

import numpy as np
import paddle


def device2str(type=None, index=None, *, device=None):
    """Normalize torch-style device specs ("cuda:0", CUDAPlace, ...) to paddle strings."""
    type = device if device else type
    if isinstance(type, int):
        type = f"gpu:{type}"
    elif isinstance(type, str):
        if "cuda" in type:
            type = type.replace("cuda", "gpu")
        if "cpu" in type:
            type = "cpu"
        elif index is not None:
            type = f"{type}:{index}"
    elif isinstance(type, paddle.CPUPlace) or (type is None):
        type = "cpu"
    elif isinstance(type, paddle.CUDAPlace):
        type = f"gpu:{type.get_device_id()}"
    return type


IGNORE_ID = -1


def pad_list(xs: List[paddle.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [paddle.ones([4]), paddle.ones([2]), paddle.ones([1])]
        >>> pad_list(x, 0)
        Tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    max_len = max([len(item) for item in xs])
    batchs = len(xs)
    ndim = xs[0].ndim
    # paddle.zeros takes a shape list and has no `device` argument;
    # tensors are created on the current default place.
    if ndim == 1:
        pad_res = paddle.zeros([batchs, max_len], dtype=xs[0].dtype)
    elif ndim == 2:
        pad_res = paddle.zeros([batchs, max_len, xs[0].shape[1]], dtype=xs[0].dtype)
    elif ndim == 3:
        pad_res = paddle.zeros(
            [batchs, max_len, xs[0].shape[1], xs[0].shape[2]], dtype=xs[0].dtype
        )
    else:
        raise ValueError(f"Unsupported ndim: {ndim}")
    pad_res.fill_(pad_value)
    for i in range(batchs):
        pad_res[i, : len(xs[i])] = xs[i]
    return pad_res


def th_accuracy(
    pad_outputs: paddle.Tensor, pad_targets: paddle.Tensor, ignore_label: int
) -> paddle.Tensor:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (Tensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        paddle.Tensor: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.reshape(
        [pad_targets.shape[0], pad_targets.shape[1], pad_outputs.shape[1]]
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = paddle.sum(
        (pad_pred.masked_select(mask) == pad_targets.masked_select(mask)).astype(
            "float32"
        )
    )
    denominator = paddle.sum(mask.astype("float32"))
    return (numerator / denominator).detach()


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # paddle equivalent of torch's m.weight.data.normal_(mean, std)
        paddle.nn.initializer.Normal(mean=mean, std=std)(m.weight)


def ras_sampling(
    weighted_scores,
    decoded_tokens,
    sampling,
    top_p=0.8,
    top_k=25,
    win_size=10,
    tau_r=0.1,
):
    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
    rep_num = (
        (
            paddle.to_tensor(
                decoded_tokens[-win_size:], dtype="int64", place=weighted_scores.place
            )
            == top_ids
        )
        .sum()
        .item()
    )
    # fall back to plain multinomial sampling when the recent window repeats
    if rep_num >= win_size * tau_r:
        top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)[0]
    return top_ids


def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
    prob, indices = [], []
    cum_prob = 0.0
    scores = weighted_scores.softmax(axis=0)
    sorted_value = paddle.sort(scores, descending=True)
    sorted_idx = paddle.argsort(scores, descending=True)
    for i in range(len(sorted_idx)):
        if cum_prob < top_p and len(prob) < top_k:
            cum_prob += sorted_value[i].item()
            prob.append(sorted_value[i])
            indices.append(sorted_idx[i])
        else:
            break
    prob = paddle.to_tensor(prob, place=weighted_scores.place)
    indices = paddle.to_tensor(indices, dtype="int64", place=weighted_scores.place)
    # sample from the truncated distribution; the hard-coded `indices[0]`
    # argmax shortcut is replaced with the multinomial draw it commented out,
    # matching the duplicate implementation in llm.py
    top_ids = indices[paddle.multinomial(prob, num_samples=1, replacement=True)]
    return top_ids


def random_sampling(weighted_scores, decoded_tokens, sampling):
    top_ids = paddle.multinomial(
        weighted_scores.softmax(axis=0), num_samples=1, replacement=True
    )
    return top_ids


def fade_in_out(fade_in_mel, fade_out_mel, window):
    device = fade_in_mel.place
    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
    mel_overlap_len = int(window.shape[0] / 2)
    # clone before the in-place edit when the caller's tensor already lived on cpu
    if device2str(device=device) == "cpu":
        fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :mel_overlap_len] = (
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
        + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    )
    return fade_in_mel.to(device)


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)


def mask_to_bias(mask: paddle.Tensor, dtype: paddle.dtype) -> paddle.Tensor:
    assert mask.dtype == paddle.bool
    assert dtype in [paddle.float32, paddle.bfloat16, paddle.float16]
    mask = mask.astype(dtype)
    # attention bias: 0 where attended, a large negative value elsewhere
    mask = (1.0 - mask) * -10000000000.0
    return mask


class TrtContextWrapper:
    def __init__(self, trt_engine, trt_concurrent=1, device="cuda:0"):
        self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
        self.trt_engine = trt_engine
        for _ in range(trt_concurrent):
            trt_context = trt_engine.create_execution_context()
            # store the Stream itself; stream_guard is a context manager, not a stream
            trt_stream = paddle.device.Stream(device=device2str(device))
            assert (
                trt_context is not None
            ), "failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent {}".format(
                trt_concurrent
            )
            self.trt_context_pool.put([trt_context, trt_stream])
        assert self.trt_context_pool.empty() is False, "no available estimator context"

    def acquire_estimator(self):
        return self.trt_context_pool.get(), self.trt_engine

    def release_estimator(self, context, stream):
        self.trt_context_pool.put([context, stream])
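
A small sanity check of the repetition-aware sampling above (a sketch; run with the definitions in this file in scope):

    # toy logits over a 6-token vocabulary; token 1 was decoded 4 times in a row
    scores = paddle.to_tensor([0.1, 2.0, 0.3, 0.05, 1.5, 0.2])
    top_ids = ras_sampling(
        scores, decoded_tokens=[1, 1, 1, 1], sampling=25, win_size=4, tau_r=0.5
    )
    # if the nucleus draw repeats token 1 often enough within the window,
    # ras_sampling falls back to random_sampling over the full softmax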
@ -0,0 +1,374 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
from typing import Generator

import paddle
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
from tqdm import tqdm

logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s"
)

# NOTE: CosyVoiceModel is assumed to live alongside CosyVoice2Model; the
# original hunk imported only CosyVoice2Model although both are used below.
from paddlespeech.t2s.models.CosyVoice.frontend import CosyVoiceFrontEnd
from paddlespeech.t2s.models.CosyVoice.model import CosyVoice2Model, CosyVoiceModel


def get_model_type(configs):
    # NOTE: CosyVoice2Model inherits CosyVoiceModel. TransformerLM/Qwen2LM and
    # the MaskedDiffWithXvec/CausalMaskedDiffWithXvec/HiFTGenerator classes
    # referenced here are expected to come from the sibling llm/flow/hift
    # modules; this helper is only exercised by the commented-out asserts below.
    if (
        isinstance(configs["llm"], TransformerLM)
        and isinstance(configs["flow"], MaskedDiffWithXvec)
        and isinstance(configs["hift"], HiFTGenerator)
    ):
        return CosyVoiceModel
    if (
        isinstance(configs["llm"], Qwen2LM)
        and isinstance(configs["flow"], CausalMaskedDiffWithXvec)
        and isinstance(configs["hift"], HiFTGenerator)
    ):
        return CosyVoice2Model
    raise TypeError("No valid model type found!")


class CosyVoice:
    def __init__(
        self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1
    ):
        self.instruct = "-Instruct" in model_dir
        self.model_dir = model_dir
        self.fp16 = fp16
        if not os.path.exists(model_dir):
            model_dir = snapshot_download(model_dir)
        hyper_yaml_path = "{}/cosyvoice.yaml".format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError("{} not found!".format(hyper_yaml_path))
        with open(hyper_yaml_path, "r") as f:
            configs = load_hyperpyyaml(f)
        # assert (
        #     get_model_type(configs) != CosyVoice2Model
        # ), "do not use {} for CosyVoice initialization!".format(model_dir)
        self.frontend = CosyVoiceFrontEnd(
            configs["get_tokenizer"],
            configs["feat_extractor"],
            "{}/campplus.onnx".format(model_dir),
            "{}/speech_tokenizer_v1.onnx".format(model_dir),
            "{}/spk2info.pt".format(model_dir),
            configs["allowed_special"],
        )
        self.sample_rate = configs["sample_rate"]
        if paddle.device.cuda.device_count() == 0 and (load_jit or load_trt or fp16):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning("no cuda device, set load_jit/load_trt/fp16 to False")
        self.model = CosyVoiceModel(
            configs["llm"], configs["flow"], configs["hift"], fp16
        )
        self.model.load(
            "{}/llm.pt".format(model_dir),
            "{}/flow.pt".format(model_dir),
            "{}/hift.pt".format(model_dir),
        )
        if load_jit:
            self.model.load_jit(
                "{}/llm.text_encoder.{}.zip".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                ),
                "{}/llm.llm.{}.zip".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                ),
                "{}/flow.encoder.{}.zip".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                ),
            )
        if load_trt:
            self.model.load_trt(
                "{}/flow.decoder.estimator.{}.mygpu.plan".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                ),
                "{}/flow.decoder.estimator.fp32.onnx".format(model_dir),
                trt_concurrent,
                self.fp16,
            )
        del configs

    def list_available_spks(self):
        spks = list(self.frontend.spk2info.keys())
        return spks

    def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id):
        assert zero_shot_spk_id != "", "do not use empty zero_shot_spk_id"
        model_input = self.frontend.frontend_zero_shot(
            "", prompt_text, prompt_speech_16k, self.sample_rate, ""
        )
        del model_input["text"]
        del model_input["text_len"]
        self.frontend.spk2info[zero_shot_spk_id] = model_input
        return True

    def save_spkinfo(self):
        paddle.save(
            obj=self.frontend.spk2info, path="{}/spk2info.pt".format(self.model_dir)
        )

    def inference_sft(
        self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True
    ):
        for i in tqdm(
            self.frontend.text_normalize(
                tts_text, split=True, text_frontend=text_frontend
            )
        ):
            model_input = self.frontend.frontend_sft(i, spk_id)
            start_time = time.time()
            logging.info("synthesis text {}".format(i))
            for model_output in self.model.tts(
                **model_input, stream=stream, speed=speed
            ):
                speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
                logging.info(
                    "yield speech len {}, rtf {}".format(
                        speech_len, (time.time() - start_time) / speech_len
                    )
                )
                yield model_output
                start_time = time.time()

    def inference_zero_shot(
        self,
        tts_text,
        prompt_text,
        prompt_speech_16k,
        zero_shot_spk_id="",
        stream=False,
        speed=1.0,
        text_frontend=True,
    ):
        prompt_text = self.frontend.text_normalize(
            prompt_text, split=False, text_frontend=text_frontend
        )
        for i in tqdm(
            self.frontend.text_normalize(
                tts_text, split=True, text_frontend=text_frontend
            )
        ):
            if not isinstance(i, Generator) and len(i) < 0.5 * len(prompt_text):
                logging.warning(
                    "synthesis text {} is much shorter than prompt text {}, "
                    "this may lead to bad performance".format(i, prompt_text)
                )
            model_input = self.frontend.frontend_zero_shot(
                i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id
            )
            start_time = time.time()
            logging.info("synthesis text {}".format(i))
            for model_output in self.model.tts(
                **model_input, stream=stream, speed=speed
            ):
                speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
                logging.info(
                    "yield speech len {}, rtf {}".format(
                        speech_len, (time.time() - start_time) / speech_len
                    )
                )
                yield model_output
                start_time = time.time()

    def inference_cross_lingual(
        self,
        tts_text,
        prompt_speech_16k,
        zero_shot_spk_id="",
        stream=False,
        speed=1.0,
        text_frontend=True,
    ):
        for i in tqdm(
            self.frontend.text_normalize(
                tts_text, split=True, text_frontend=text_frontend
            )
        ):
            model_input = self.frontend.frontend_cross_lingual(
                i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id
            )
            start_time = time.time()
            logging.info("synthesis text {}".format(i))
            for model_output in self.model.tts(
                **model_input, stream=stream, speed=speed
            ):
                speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
                logging.info(
                    "yield speech len {}, rtf {}".format(
                        speech_len, (time.time() - start_time) / speech_len
                    )
                )
                yield model_output
                start_time = time.time()

    def inference_instruct(
        self,
        tts_text,
        spk_id,
        instruct_text,
        stream=False,
        speed=1.0,
        text_frontend=True,
    ):
        assert isinstance(
            self.model, CosyVoiceModel
        ), "inference_instruct is only implemented for CosyVoice!"
        if self.instruct is False:
            raise ValueError(
                "{} does not support instruct inference".format(self.model_dir)
            )
        instruct_text = self.frontend.text_normalize(
            instruct_text, split=False, text_frontend=text_frontend
        )
        for i in tqdm(
            self.frontend.text_normalize(
                tts_text, split=True, text_frontend=text_frontend
            )
        ):
            model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
            start_time = time.time()
            logging.info("synthesis text {}".format(i))
            for model_output in self.model.tts(
                **model_input, stream=stream, speed=speed
            ):
                speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
                logging.info(
                    "yield speech len {}, rtf {}".format(
                        speech_len, (time.time() - start_time) / speech_len
                    )
                )
                yield model_output
                start_time = time.time()

    def inference_vc(
        self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0
    ):
        model_input = self.frontend.frontend_vc(
            source_speech_16k, prompt_speech_16k, self.sample_rate
        )
        start_time = time.time()
        for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
            speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
            logging.info(
                "yield speech len {}, rtf {}".format(
                    speech_len, (time.time() - start_time) / speech_len
                )
            )
            yield model_output
            start_time = time.time()


class CosyVoice2(CosyVoice):
    def __init__(
        self,
        model_dir,
        load_jit=False,
        load_trt=False,
        load_vllm=False,
        fp16=False,
        trt_concurrent=1,
    ):
        self.instruct = "-Instruct" in model_dir
        self.model_dir = model_dir
        self.fp16 = fp16
        hyper_yaml_path = "{}/cosyvoice2.yaml".format(model_dir)
        if not os.path.exists(hyper_yaml_path):
            raise ValueError("{} not found!".format(hyper_yaml_path))
        with open(hyper_yaml_path, "r") as f:
            configs = load_hyperpyyaml(
                f,
                overrides={
                    "qwen_pretrain_path": os.path.join(model_dir, "CosyVoice-BlankEN")
                },
            )
        # assert (
        #     get_model_type(configs) == CosyVoice2Model
        # ), "do not use {} for CosyVoice2 initialization!".format(model_dir)
        self.frontend = CosyVoiceFrontEnd(
            configs["get_tokenizer"],
            configs["feat_extractor"],
            "{}/campplus.onnx".format(model_dir),
            "{}/speech_tokenizer_v2.onnx".format(model_dir),
            "{}/spk2info.pt".format(model_dir),
            configs["allowed_special"],
        )
        self.sample_rate = configs["sample_rate"]
        if paddle.device.cuda.device_count() == 0 and (load_jit or load_trt or fp16):
            load_jit, load_trt, fp16 = False, False, False
            logging.warning("no cuda device, set load_jit/load_trt/fp16 to False")
        self.model = CosyVoice2Model(
            configs["llm"], configs["flow"], configs["hift"], fp16
        )
        self.model.load(
            "{}/llm.pt".format(model_dir),
            "{}/flow.pt".format(model_dir),
            "{}/hift.pt".format(model_dir),
        )
        if load_vllm:
            self.model.load_vllm("{}/vllm".format(model_dir))
        if load_jit:
            self.model.load_jit(
                "{}/flow.encoder.{}.zip".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                )
            )
        if load_trt:
            self.model.load_trt(
                "{}/flow.decoder.estimator.{}.mygpu.plan".format(
                    model_dir, "fp16" if self.fp16 else "fp32"
                ),
                "{}/flow.decoder.estimator.fp32.onnx".format(model_dir),
                trt_concurrent,
                self.fp16,
            )
        del configs

    def inference_instruct(self, *args, **kwargs):
        raise NotImplementedError(
            "inference_instruct is not implemented for CosyVoice2!"
        )

    def inference_instruct2(
        self,
        tts_text,
        instruct_text,
        prompt_speech_16k,
        zero_shot_spk_id="",
        stream=False,
        speed=1.0,
        text_frontend=True,
    ):
        assert isinstance(
            self.model, CosyVoice2Model
        ), "inference_instruct2 is only implemented for CosyVoice2!"
        for i in tqdm(
            self.frontend.text_normalize(
                tts_text, split=True, text_frontend=text_frontend
            )
        ):
            model_input = self.frontend.frontend_instruct2(
                i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id
            )
            start_time = time.time()
            logging.info("synthesis text {}".format(i))
            for model_output in self.model.tts(
                **model_input, stream=stream, speed=speed
            ):
                speech_len = model_output["tts_speech"].shape[1] / self.sample_rate
                logging.info(
                    "yield speech len {}, rtf {}".format(
                        speech_len, (time.time() - start_time) / speech_len
                    )
                )
                yield model_output
                start_time = time.time()
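
A minimal zero-shot usage sketch for the class above. The paths are placeholders and `load_wav` is assumed to come from the repo's audio utilities; streaming simply yields several chunks instead of one:

    cosy = CosyVoice2("pretrained_models/CosyVoice2-0.5B_paddle")
    prompt = load_wav("prompt.wav", 16000)  # 16 kHz reference recording
    for out in cosy.inference_zero_shot(
        "Text to speak.", "Transcript of the prompt audio.", prompt
    ):
        chunk = out["tts_speech"]  # (1, T) waveform at cosy.sample_rate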
@ -0,0 +1,267 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any
from typing import Dict
from typing import List

import paddle
from einops import pack, rearrange, repeat
from paddle import nn
from paddle.nn import functional as F

# NOTE: SinusoidalPosEmb, TimestepEmbedding, ResnetBlock1D, Downsample1D,
# Upsample1D, Block1D, ConformerWrapper and BasicTransformerBlock are assumed
# to be provided by the accompanying matcha-style decoder modules.


class Decoder(nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )

        # paddle uses LayerList/Conv1D where torch has ModuleList/Conv1d
        self.down_blocks = nn.LayerList([])
        self.mid_blocks = nn.LayerList([])
        self.up_blocks = nn.LayerList([])

        output_channel = in_channels
        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(
                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
            )
            transformer_blocks = nn.LayerList(
                [
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel)
                if not is_last
                else nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )

            self.down_blocks.append(
                nn.LayerList([resnet, transformer_blocks, downsample])
            )

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            # was `out_channels = channels[-1]`, which shadowed the ctor argument
            output_channel = channels[-1]

            resnet = ResnetBlock1D(
                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
            )

            transformer_blocks = nn.LayerList(
                [
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.LayerList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2

            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.LayerList(
                [
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )

            self.up_blocks.append(nn.LayerList([resnet, transformer_blocks, upsample]))

        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1D(channels[-1], self.out_channels, 1)

        self.initialize_weights()
        # nn.init.normal_(self.final_proj.weight)

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")

        return block

    def initialize_weights(self):
        # dygraph-mode paddle initializers stand in for torch's nn.init calls;
        # KaimingNormal defaults to the relu gain used by the original code
        for m in self.sublayers():
            if isinstance(m, (nn.Conv1D, nn.Linear)):
                nn.initializer.KaimingNormal()(m.weight)
                if m.bias is not None:
                    nn.initializer.Constant(0.0)(m.bias)
            elif isinstance(m, nn.GroupNorm):
                nn.initializer.Constant(1.0)(m.weight)
                nn.initializer.Constant(0.0)(m.bias)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (paddle.Tensor): shape (batch_size, in_channels, time)
            mask (paddle.Tensor): shape (batch_size, 1, time)
            mu (paddle.Tensor): conditioning features, shape (batch_size, in_channels, time)
            t (paddle.Tensor): shape (batch_size,)
            spks (paddle.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (paddle.Tensor, optional): placeholder for future use. Defaults to None.

        Returns:
            paddle.Tensor: output of shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c")
            mask_down = rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_down,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_down = rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c")
            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_mid,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_mid = rearrange(mask_mid, "b t -> b 1 t")

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = rearrange(x, "b c t -> b t c")
            mask_up = rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=mask_up,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t")
            mask_up = rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)

        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)

        return output * mask
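
The down path above keeps a mask per resolution, halving it alongside each Downsample1D so the up path can pop a matching mask for every skip connection. A tiny sketch of that invariant:

    import paddle
    mask = paddle.ones([2, 1, 8], dtype="float32")  # (batch, 1, time)
    masks = [mask]
    masks.append(masks[-1][:, :, ::2])  # mirrors masks.append(mask_down[:, :, ::2])
    print(masks[-1].shape)              # [2, 1, 4] -- one stride-2 downsample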
@ -0,0 +1,735 @@
|
||||
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, Generator, List, Optional
|
||||
import logging
|
||||
import paddle.nn.functional as F
|
||||
import paddle
|
||||
IGNORE_ID = -1
|
||||
import torch
|
||||
LabelSmoothingLoss = None
|
||||
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
recent_tokens = paddle.to_tensor(decoded_tokens[-win_size:], dtype='int64')
|
||||
rep_num = paddle.sum(recent_tokens.cpu() == top_ids.cpu()).cpu().item()
|
||||
if rep_num >= win_size * tau_r:
|
||||
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||
return top_ids
|
||||
|
||||
|
||||
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
softmax_scores = paddle.nn.functional.softmax(weighted_scores, axis=0)
|
||||
sorted_indices = paddle.argsort(softmax_scores, axis=0, descending=True)
|
||||
sorted_probs = paddle.gather(softmax_scores, sorted_indices, axis=0)
|
||||
|
||||
prob_list = []
|
||||
indices_list = []
|
||||
cum_prob = 0.0
|
||||
|
||||
for i in range(len(sorted_indices)):
|
||||
if cum_prob < top_p and len(prob_list) < top_k:
|
||||
cum_prob += sorted_probs[i].item()
|
||||
prob_list.append(sorted_probs[i])
|
||||
indices_list.append(sorted_indices[i])
|
||||
else:
|
||||
break
|
||||
|
||||
prob_tensor = paddle.to_tensor(prob_list, dtype=weighted_scores.dtype)
|
||||
indices_tensor = paddle.to_tensor(indices_list, dtype='int64')
|
||||
top_ids = indices_tensor[paddle.multinomial(prob_tensor, num_samples=1, replacement=True)]
|
||||
|
||||
return top_ids
|
||||
|
||||
|
||||
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||
probs = paddle.nn.functional.softmax(weighted_scores, axis=0)
|
||||
top_ids = paddle.multinomial(probs, num_samples=1, replacement=True)
|
||||
return top_ids
|
||||
def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor:
|
||||
batch_size = lengths.shape[0]
|
||||
max_len = max_len if max_len > 0 else lengths.max().item()
|
||||
seq_range = paddle.arange(0, max_len, dtype='int64')
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
|
||||
seq_length_expand = lengths.unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
return mask
|
||||
|
||||
def th_accuracy(pad_outputs: paddle.Tensor, pad_targets: paddle.Tensor,
|
||||
ignore_label: int) -> paddle.Tensor:
|
||||
pad_pred = pad_outputs.reshape((pad_targets.shape[0], pad_targets.shape[1], -1)).argmax(axis=2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = paddle.sum((pad_pred[mask] == pad_targets[mask]).astype('float32'))
|
||||
denominator = paddle.sum(mask.astype('float32'))
|
||||
accuracy = numerator / denominator
|
||||
|
||||
return accuracy.detach()
|
||||
class TransformerLM(paddle.nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder_input_size: int,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
text_token_size: int,
|
||||
speech_token_size: int,
|
||||
text_encoder: paddle.nn.Layer,
|
||||
llm: paddle.nn.Layer,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
spk_embed_dim: int = 192,
|
||||
):
|
||||
super().__init__()
|
||||
self.llm_input_size = llm_input_size
|
||||
self.speech_token_size = speech_token_size
|
||||
self.text_embedding = paddle.nn.Embedding(
|
||||
text_token_size, text_encoder_input_size
|
||||
)
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_affine_layer = paddle.nn.Linear(
|
||||
in_features=self.text_encoder.output_size(), out_features=llm_input_size
|
||||
)
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
self.llm_embedding = paddle.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
self.llm_decoder = paddle.nn.Linear(
|
||||
in_features=llm_output_size, out_features=speech_token_size + 1
|
||||
)
|
||||
|
||||
self.criterion_ce = LabelSmoothingLoss(
|
||||
size=speech_token_size + 1,
|
||||
padding_idx=IGNORE_ID,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
self.speech_embedding = paddle.nn.Embedding(speech_token_size, llm_input_size)
|
||||
self.spk_embed_affine_layer = paddle.nn.Linear(
|
||||
in_features=spk_embed_dim, out_features=llm_input_size
|
||||
)
|
||||
self.sampling = sampling
|
||||
|
||||
def encode(self, text: paddle.Tensor, text_lengths: paddle.Tensor):
|
||||
encoder_out, encoder_mask = self.text_encoder(
|
||||
text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1
|
||||
)
|
||||
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
||||
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def pad_unpad_sequence(
|
||||
self,
|
||||
sos_eos_emb,
|
||||
embedding,
|
||||
text_token,
|
||||
text_token_len,
|
||||
task_id_emb,
|
||||
speech_token,
|
||||
speech_token_len,
|
||||
):
|
||||
|
||||
text_token = paddle.static.nn.sequence_unpad(
|
||||
text_token, text_token_len.cpu()
|
||||
)
|
||||
speech_token = paddle.static.nn.sequence_unpad(
|
||||
speech_token, speech_token_len.cpu()
|
||||
)
|
||||
lm_input = [
|
||||
paddle.cat(
|
||||
[
|
||||
sos_eos_emb.squeeze(dim=0),
|
||||
embedding[i],
|
||||
text_token[i],
|
||||
task_id_emb.squeeze(dim=0),
|
||||
speech_token[i],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
for i in range(len(text_token))
|
||||
]
|
||||
lm_input_len = paddle.tensor([i.size(0) for i in lm_input], dtype=paddle.int32)
|
||||
lm_input = paddle.static.nn.sequence_unpad(
|
||||
lm_input, batch_first=True, padding_value=IGNORE_ID
|
||||
)
|
||||
return lm_input, lm_input_len
|
||||
|
||||
def forward(
|
||||
self, batch: dict, device: torch.device
|
||||
) -> Dict[str, Optional[paddle.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text: (B, L, D)
|
||||
text_lengths: (B,)
|
||||
audio: (B, T, N) or (B, T)
|
||||
audio_lengths: (B,)
|
||||
"""
|
||||
text_token = batch["text_token"].to(device)
|
||||
text_token_len = batch["text_token_len"].to(device)
|
||||
speech_token = batch["speech_token"].to(device)
|
||||
speech_token_len = batch["speech_token_len"].to(device)
|
||||
embedding = batch["embedding"].to(device)
|
||||
lm_target = [
|
||||
paddle.tensor(
|
||||
[IGNORE_ID] * (2 + text_token_len[i])
|
||||
+ speech_token[i, : speech_token_len[i]].tolist()
|
||||
+ [self.speech_token_size]
|
||||
)
|
||||
for i in range(text_token.size(0))
|
||||
]
|
||||
lm_target = torch.nn.utils.rnn.pad_sequence(
|
||||
lm_target, batch_first=True, padding_value=IGNORE_ID
|
||||
).to(device)
|
||||
text_token = self.text_embedding(text_token)
|
||||
text_token, text_token_len = self.encode(text_token, text_token_len)
|
||||
embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(1)
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
speech_token = self.speech_embedding(speech_token)
|
||||
lm_input, lm_input_len = self.pad_unpad_sequence(
|
||||
sos_eos_emb,
|
||||
embedding,
|
||||
text_token,
|
||||
text_token_len,
|
||||
task_id_emb,
|
||||
speech_token,
|
||||
speech_token_len,
|
||||
)
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
loss = self.criterion_ce(logits, lm_target)
|
||||
acc = th_accuracy(
|
||||
logits.view(-1, self.speech_token_size + 1),
|
||||
lm_target,
|
||||
ignore_label=IGNORE_ID,
|
||||
)
|
||||
return {"loss": loss, "acc": acc}
|
||||
|
||||
def sampling_ids(
|
||||
self,
|
||||
weighted_scores: paddle.Tensor,
|
||||
decoded_tokens: List,
|
||||
sampling: int,
|
||||
ignore_eos: bool = True,
|
||||
):
|
||||
num_trials, max_trials = 0, 100
|
||||
while True:
|
||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||
if (not ignore_eos) or (top_ids < self.speech_token_size):
|
||||
break
|
||||
num_trials += 1
|
||||
if num_trials > max_trials:
|
||||
raise RuntimeError(
|
||||
"sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!".format(
|
||||
max_trials
|
||||
)
|
||||
)
|
||||
return top_ids
|
||||
|
||||
@paddle.no_grad()
|
||||
def inference(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_len: paddle.Tensor,
|
||||
prompt_text: paddle.Tensor,
|
||||
prompt_text_len: paddle.Tensor,
|
||||
prompt_speech_token: paddle.Tensor,
|
||||
prompt_speech_token_len: paddle.Tensor,
|
||||
embedding: paddle.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = "",
|
||||
) -> Generator[paddle.Tensor, None, None]:
|
||||
device = text.place
|
||||
text = paddle.cat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.text_embedding(text)
|
||||
text, text_len = self.encode(text, text_len)
|
||||
if embedding.shape[0] != 0:
|
||||
embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(dim=1)
|
||||
else:
|
||||
embedding = (
|
||||
paddle.zeros(1, 0, self.llm_input_size, dtype=text.dtype)
|
||||
.to(device)
|
||||
.to(text.dtype)
|
||||
)
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = paddle.zeros(
|
||||
1, 0, self.llm_input_size, dtype=text.dtype
|
||||
).to(device)
|
||||
lm_input = paddle.cat(
|
||||
[sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1
|
||||
)
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
out_tokens = []
|
||||
offset = 0
|
||||
att_cache, cnn_cache = paddle.zeros(
|
||||
(0, 0, 0, 0), device=lm_input.place
|
||||
), paddle.zeros((0, 0, 0, 0), device=lm_input.place)
|
||||
for i in range(max_len):
|
||||
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(
|
||||
lm_input,
|
||||
offset=offset,
|
||||
required_cache_size=-1,
|
||||
att_cache=att_cache,
|
||||
cnn_cache=cnn_cache,
|
||||
att_mask=paddle.tril(
|
||||
paddle.ones(
|
||||
(1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.place
|
||||
)
|
||||
).to(paddle.bool),
|
||||
)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
if i == 0:
|
||||
logp[:, self.speech_token_size] = -float("inf")
|
||||
top_ids = self.sampling_ids(
|
||||
logp.squeeze(dim=0),
|
||||
out_tokens,
|
||||
sampling,
|
||||
ignore_eos=True if i < min_len else False,
|
||||
).item()
|
||||
if top_ids == self.speech_token_size:
|
||||
break
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
offset += lm_input.size(1)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
|
||||
class Qwen2Encoder(paddle.nn.Layer):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(self, xs: paddle.Tensor, xs_lens: paddle.Tensor):
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T)
|
||||
outs = self.model(
|
||||
inputs_embeds=xs,
|
||||
attention_mask=masks,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
return outs.hidden_states[-1], masks.unsqueeze(1)
|
||||
|
||||
def forward_one_step(self, xs, masks, cache=None,idx = 0):
|
||||
|
||||
input_masks = masks[:, -1, :]
|
||||
outs = self.model(
|
||||
inputs_embeds=xs,
|
||||
attention_mask=input_masks,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
past_key_values=cache,
|
||||
index =idx
|
||||
)
|
||||
xs = outs.hidden_states[-1]
|
||||
new_cache = outs.past_key_values
|
||||
xs = paddle.cast(xs, dtype = 'float32')
|
||||
return xs, new_cache
|
||||
|
||||
|
||||
class Qwen2LM(TransformerLM):
|
||||
def __init__(
|
||||
self,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
speech_token_size: int,
|
||||
llm: paddle.nn.Layer,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
paddle.nn.Layer.__init__(self)
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
self.sos_eos = 0
|
||||
self.task_id = 1
|
||||
self.fill_token = 2
|
||||
self.llm_embedding = paddle.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
self.llm_decoder = paddle.nn.Linear(
|
||||
in_features=llm_output_size, out_features=speech_token_size + 3
|
||||
)
|
||||
self.speech_embedding = paddle.nn.Embedding(
|
||||
speech_token_size + 3, llm_input_size
|
||||
)
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
self.stop_token_ids = [(speech_token_size + i) for i in range(3)]
|
||||
self.vllm_output_queue = {}
|
||||
|
||||
def prepare_lm_input_target(
|
||||
self,
|
||||
text_token,
|
||||
text_token_emb,
|
||||
text_token_len,
|
||||
speech_token,
|
||||
speech_token_emb,
|
||||
speech_token_len,
|
||||
):
|
||||
lm_target, lm_input = [], []
|
||||
text_token = torch.nn.utils.rnn.unpad_sequence(
|
||||
text_token, text_token_len.cpu(), batch_first=True
|
||||
)
|
||||
speech_token = torch.nn.utils.rnn.unpad_sequence(
|
||||
speech_token, speech_token_len.cpu(), batch_first=True
|
||||
)
|
||||
text_token_emb = torch.nn.utils.rnn.unpad_sequence(
|
||||
text_token_emb, text_token_len.cpu(), batch_first=True
|
||||
)
|
||||
speech_token_emb = torch.nn.utils.rnn.unpad_sequence(
|
||||
speech_token_emb, speech_token_len.cpu(), batch_first=True
|
||||
)
|
||||
for i in range(len(text_token)):
|
||||
if (
|
||||
random.random() < 0.5
|
||||
and speech_token_len[i] / text_token_len[i]
|
||||
> self.mix_ratio[1] / self.mix_ratio[0]
|
||||
):
|
||||
this_lm_target, this_lm_input = [], []
|
||||
this_lm_target.append(IGNORE_ID)
|
||||
this_lm_input.append(
|
||||
self.llm_embedding.weight[self.sos_eos].reshape(1, -1)
|
||||
)
|
||||
for j in range(
|
||||
((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()
|
||||
):
|
||||
this_text_token = text_token[i][
|
||||
j * self.mix_ratio[0] : (j + 1) * self.mix_ratio[0]
|
||||
].tolist()
|
||||
this_speech_token = speech_token[i][
|
||||
j * self.mix_ratio[1] : (j + 1) * self.mix_ratio[1]
|
||||
].tolist()
|
||||
if len(this_text_token) == self.mix_ratio[0]:
|
||||
assert len(this_speech_token) == self.mix_ratio[1]
|
||||
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
|
||||
this_lm_target += this_speech_token
|
||||
this_lm_target.append(self.speech_token_size + 2)
|
||||
this_lm_input.append(
|
||||
text_token_emb[i][
|
||||
j * self.mix_ratio[0] : (j + 1) * self.mix_ratio[0]
|
||||
]
|
||||
)
|
||||
this_lm_input.append(
|
||||
speech_token_emb[i][
|
||||
j * self.mix_ratio[1] : (j + 1) * self.mix_ratio[1]
|
||||
]
|
||||
)
|
||||
else:
|
||||
this_lm_target += [-1] * len(this_text_token)
|
||||
this_lm_target += speech_token[i][
|
||||
j * self.mix_ratio[1] :
|
||||
].tolist()
|
||||
this_lm_target.append(self.speech_token_size)
|
||||
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0] :])
|
||||
this_lm_input.append(
|
||||
self.llm_embedding.weight[self.task_id].reshape(1, -1)
|
||||
)
|
||||
this_lm_input.append(
|
||||
speech_token_emb[i][j * self.mix_ratio[1] :]
|
||||
)
|
||||
this_lm_target, this_lm_input = paddle.tensor(
|
||||
this_lm_target
|
||||
), paddle.cat(this_lm_input, dim=0)
|
||||
else:
|
||||
this_lm_target = paddle.tensor(
|
||||
[IGNORE_ID] * (1 + text_token_len[i])
|
||||
+ speech_token[i].tolist()
|
||||
+ [self.speech_token_size]
|
||||
)
|
||||
this_lm_input = paddle.cat(
|
||||
[
|
||||
self.llm_embedding.weight[self.sos_eos].reshape(1, -1),
|
||||
text_token_emb[i],
|
||||
self.llm_embedding.weight[self.task_id].reshape(1, -1),
|
||||
speech_token_emb[i],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
lm_target.append(this_lm_target)
|
||||
lm_input.append(this_lm_input)
|
||||
lm_input_len = paddle.tensor([i.size(0) for i in lm_input], dtype=paddle.int32)
|
||||
lm_input = torch.nn.utils.rnn.pad_sequence(
|
||||
lm_input, batch_first=True, padding_value=IGNORE_ID
|
||||
)
|
||||
lm_target = torch.nn.utils.rnn.pad_sequence(
|
||||
lm_target, batch_first=True, padding_value=IGNORE_ID
|
||||
)
|
||||
return lm_target, lm_input, lm_input_len
|
||||
|
||||
@paddle.no_grad()
|
||||
def inference(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_len: paddle.Tensor,
|
||||
prompt_text: paddle.Tensor,
|
||||
prompt_text_len: paddle.Tensor,
|
||||
prompt_speech_token: paddle.Tensor,
|
||||
prompt_speech_token_len: paddle.Tensor,
|
||||
embedding: paddle.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = "",
|
||||
) -> Generator[paddle.Tensor, None, None]:
|
||||
device = text.place
|
||||
text = paddle.cat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.llm.model.qwen2.embed_tokens(text)
|
||||
sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape([1, 1, -1])
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape([1, 1, -1])
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = paddle.zeros(
|
||||
1, 0, self.llm_input_size, dtype=text.dtype
|
||||
).to(device)
|
||||
text = paddle.cast(text,dtype = 'float32')
|
||||
lm_input = paddle.cat(
|
||||
[sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1
|
||||
)
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
|
||||
|
||||
@paddle.no_grad()
|
||||
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
|
||||
if hasattr(self, "vllm"):
|
||||
from vllm import RequestOutput, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
top_k=sampling,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
min_tokens=min_len,
|
||||
max_tokens=max_len,
|
||||
)
|
||||
with self.lock:
|
||||
self.vllm.add_request(
|
||||
uuid,
|
||||
{
|
||||
"prompt_embeds": lm_input.squeeze(0)
|
||||
.to(paddle.bfloat16)
|
||||
.to(lm_input.place)
|
||||
},
|
||||
sampling_params,
|
||||
)
|
||||
self.vllm_output_queue[uuid] = queue.Queue()
|
||||
out_tokens = []
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.vllm_output_queue[uuid].empty() is True:
|
||||
request_outputs: List[RequestOutput] = self.vllm.step()
|
||||
for request_output in request_outputs:
|
||||
top_ids = list(request_output.outputs[0].token_ids)[-1]
|
||||
self.vllm_output_queue[request_output.request_id].put(
|
||||
top_ids
|
||||
)
|
||||
if self.vllm_output_queue[uuid].empty() is False:
|
||||
top_ids = self.vllm_output_queue[uuid].get()
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
if len(out_tokens) == max_len:
|
||||
break
|
||||
time.sleep(0.001)
|
||||
with self.lock:
|
||||
self.vllm_output_queue.pop(uuid)
|
||||
else:
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
|
||||
y_pred, cache = self.llm.forward_one_step(
|
||||
lm_input,
|
||||
masks=paddle.tril(
|
||||
paddle.ones(
|
||||
(1, lm_input.shape[1], lm_input.shape[1]),
|
||||
)
|
||||
).to(paddle.bool),
|
||||
cache=cache,
|
||||
idx = i
|
||||
)
|
||||
|
||||
logp = F.log_softmax(self.llm_decoder(y_pred[:, -1]), axis = -1)
|
||||
top_ids = self.sampling_ids(
|
||||
logp.squeeze(axis=0),
|
||||
out_tokens,
|
||||
sampling,
|
||||
ignore_eos=True if i < min_len else False,
|
||||
)
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape([1, 1, -1])
|
||||
    @paddle.no_grad()
    def inference_bistream(
        self,
        text: Generator,
        prompt_text: paddle.Tensor,
        prompt_text_len: paddle.Tensor,
        prompt_speech_token: paddle.Tensor,
        prompt_speech_token_len: paddle.Tensor,
        embedding: paddle.Tensor,
        sampling: int = 25,
        max_token_text_ratio: float = 20,
        min_token_text_ratio: float = 2,
    ) -> Generator[paddle.Tensor, None, None]:
        device = prompt_text.place
        sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape([1, 1, -1])
        task_id_emb = self.llm_embedding.weight[self.task_id].reshape([1, 1, -1])
        if prompt_speech_token_len != 0:
            prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
        else:
            prompt_speech_token_emb = paddle.zeros(
                [1, 0, self.llm_input_size], dtype=prompt_text.dtype
            ).to(device)
        lm_input = paddle.concat([sos_eos_emb], axis=1)
        out_tokens = []
        cache = None
        text_cache = self.llm.model.model.embed_tokens(prompt_text)
        next_fill_index = -1
        for this_text in text:
            text_cache = paddle.concat(
                [text_cache, self.llm.model.model.embed_tokens(this_text)], axis=1
            )
            # interleave prompt text / speech tokens in mix_ratio chunks while
            # prompt speech tokens remain
            while prompt_speech_token_emb.shape[1] != 0:
                if text_cache.shape[1] >= self.mix_ratio[0]:
                    lm_input_text, lm_input_speech = (
                        text_cache[:, : self.mix_ratio[0]],
                        prompt_speech_token_emb[:, : self.mix_ratio[1]],
                    )
                    logging.info(
                        "append {} text token {} speech token".format(
                            lm_input_text.shape[1], lm_input_speech.shape[1]
                        )
                    )
                    lm_input = paddle.concat(
                        [lm_input, lm_input_text, lm_input_speech], axis=1
                    )
                    text_cache, prompt_speech_token_emb = (
                        text_cache[:, self.mix_ratio[0] :],
                        prompt_speech_token_emb[:, self.mix_ratio[1] :],
                    )
                else:
                    logging.info("not enough text token to decode, wait for more")
                    break
            if prompt_speech_token_emb.shape[1] == 0:
                if (
                    len(out_tokens) != 0
                    and out_tokens[-1] == self.speech_token_size + 2
                    or len(out_tokens) == 0
                    and lm_input.shape[1] == 1
                ):
                    logging.info("get fill token, need to append more text token")
                    if text_cache.shape[1] >= self.mix_ratio[0]:
                        lm_input_text = text_cache[:, : self.mix_ratio[0]]
                        logging.info(
                            "append {} text token".format(lm_input_text.shape[1])
                        )
                        if (
                            len(out_tokens) != 0
                            and out_tokens[-1] == self.speech_token_size + 2
                        ):
                            lm_input = lm_input_text
                        else:
                            lm_input = paddle.concat(
                                [lm_input, lm_input_text], axis=1
                            )
                        text_cache = text_cache[:, self.mix_ratio[0] :]
                    else:
                        logging.info("not enough text token to decode, wait for more")
                        continue
                while True:
                    seq_len = (
                        lm_input.shape[1]
                        if cache is None
                        else lm_input.shape[1] + cache[0][0].shape[2]
                    )
                    y_pred, cache = self.llm.forward_one_step(
                        lm_input,
                        masks=paddle.tril(
                            paddle.ones([1, seq_len, seq_len]).to(lm_input.place)
                        ).to(paddle.bool),
                        cache=cache,
                    )
                    logp = F.log_softmax(self.llm_decoder(y_pred[:, -1]), axis=-1)
                    if next_fill_index != -1 and len(out_tokens) == next_fill_index:
                        top_ids = self.speech_token_size + 2
                        next_fill_index += self.mix_ratio[1] + 1
                    else:
                        top_ids = self.sampling_ids(
                            logp.squeeze(axis=0), out_tokens, sampling, ignore_eos=True
                        ).item()
                    if top_ids == self.speech_token_size + 2:
                        next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
                        logging.info(
                            "fill_token index {} next fill_token index {}".format(
                                len(out_tokens), next_fill_index
                            )
                        )
                    out_tokens.append(top_ids)
                    if top_ids >= self.speech_token_size:
                        if top_ids == self.speech_token_size + 2:
                            break
                        else:
                            raise ValueError("should not get token {}".format(top_ids))
                    yield top_ids
                    lm_input = self.speech_embedding.weight[top_ids].reshape([1, 1, -1])
        lm_input = paddle.concat([lm_input, text_cache, task_id_emb], axis=1)
        logging.info("no more text token, decode until met eos")
        while True:
            seq_len = (
                lm_input.shape[1]
                if cache is None
                else lm_input.shape[1] + cache[0][0].shape[2]
            )
            y_pred, cache = self.llm.forward_one_step(
                lm_input,
                masks=paddle.tril(
                    paddle.ones([1, seq_len, seq_len]).to(lm_input.place)
                ).to(paddle.bool),
                cache=cache,
            )
            logp = F.log_softmax(self.llm_decoder(y_pred[:, -1]), axis=-1)
            top_ids = self.sampling_ids(
                logp.squeeze(axis=0), out_tokens, sampling, ignore_eos=False
            ).item()
            out_tokens.append(top_ids)
            if top_ids >= self.speech_token_size:
                if top_ids == self.speech_token_size:
                    break
                else:
                    raise ValueError("should not get token {}".format(top_ids))
            yield top_ids
            lm_input = self.speech_embedding.weight[top_ids].reshape([1, 1, -1])
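

# Minimal consumption sketch (illustrative only; `llm`, `tokenize` and the
# prompt tensors below are assumptions, not part of this file).
# `inference_bistream` is a generator: it yields speech token ids as text
# chunks arrive, and internally emits a fill token (speech_token_size + 2)
# whenever it must wait for more text before decoding can continue.
#
#   def text_stream():
#       for piece in ["Hello, ", "this is ", "streaming TTS."]:
#           yield tokenize(piece)          # hypothetical tokenizer, [1, T_text]
#
#   speech_tokens = []
#   for token in llm.inference_bistream(
#       text=text_stream(),
#       prompt_text=prompt_text,
#       prompt_text_len=prompt_text_len,
#       prompt_speech_token=prompt_speech_token,
#       prompt_speech_token_len=prompt_speech_token_len,
#       embedding=embedding,
#   ):
#       speech_tokens.append(token)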
@ -0,0 +1,118 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


def add_optional_chunk_mask(
    xs: paddle.Tensor,
    masks: paddle.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
    enable_full_context: bool = True,
):
    """Apply optional mask for encoder.

    Args:
        xs (paddle.Tensor): padded input, (B, L, D), L for max length
        masks (paddle.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context (max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        paddle.Tensor: chunk mask of the input xs.
    """
    if use_dynamic_chunk:
        max_len = xs.shape[1]
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            chunk_size = paddle.randint(low=1, high=max_len, shape=(1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = paddle.randint(
                        low=0, high=max_left_chunks, shape=(1,)
                    ).item()
        chunk_masks = subsequent_chunk_mask(
            xs.shape[1], chunk_size, num_left_chunks, xs.place
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.shape[1], static_chunk_size, num_left_chunks, xs.place
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == paddle.bool
    if (chunk_masks.sum(axis=-1) == 0).sum().item() != 0:
        print(
            "get chunk_masks all false at some timestep, force set to true, "
            "make sure they are masked in future computation!"
        )
        chunk_masks[chunk_masks.sum(axis=-1) == 0] = True
    return chunk_masks
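

# NOTE: `subsequent_chunk_mask` is called above but is not part of this diff.
# The function below is a hedged sketch of the WeNet-style helper it is
# assumed to be; the shipped implementation may differ.
def subsequent_chunk_mask(
    size: int, chunk_size: int, num_left_chunks: int = -1, place=None
) -> paddle.Tensor:
    """Boolean (size, size) mask: position i may attend to everything in its
    own chunk, to at most `num_left_chunks` earlier chunks (-1 = unlimited),
    and to nothing after its chunk. `place` is accepted for call
    compatibility; ops run on the default device."""
    pos = paddle.arange(size)
    chunk_idx = pos // chunk_size
    # end (exclusive) of each position's chunk, capped at `size`
    chunk_end = paddle.minimum((chunk_idx + 1) * chunk_size, paddle.to_tensor(size))
    ret = pos.unsqueeze(0) < chunk_end.unsqueeze(1)
    if num_left_chunks >= 0:
        start = paddle.clip((chunk_idx - num_left_chunks) * chunk_size, min=0)
        ret = ret & (pos.unsqueeze(0) >= start.unsqueeze(1))
    return ret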


def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.shape[0]
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = paddle.arange(0, max_len, dtype=paddle.int32)
    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
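

# Minimal usage sketch (illustrative only): pad / non-pad masks for lengths 5, 3, 2.
if __name__ == "__main__":
    lengths = paddle.to_tensor([5, 3, 2], dtype=paddle.int32)
    pad_mask = make_pad_mask(lengths)    # True at padded positions
    non_pad_mask = ~pad_mask             # True at valid positions
    print(pad_mask.astype(paddle.int32))
    # [[0, 0, 0, 0, 0],
    #  [0, 0, 0, 1, 1],
    #  [0, 0, 1, 1, 1]]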
@ -0,0 +1,607 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import threading
import time
import uuid
from contextlib import nullcontext
from typing import Generator

import numpy as np
import paddle


class CosyVoiceModel:
    def __init__(
        self,
        llm: paddle.nn.Layer,
        flow: paddle.nn.Layer,
        hift: paddle.nn.Layer,
        fp16: bool = False,
    ):
        self.device = device2str(
            "cuda" if paddle.device.cuda.device_count() >= 1 else "cpu"
        )
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 20
        self.mel_overlap_len = int(
            self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256
        )
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        self.mel_cache_len = 20
        self.source_cache_len = int(self.mel_cache_len * 256)
        self.speech_window = np.hamming(2 * self.source_cache_len)
        self.stream_scale_factor = 1
        assert (
            self.stream_scale_factor >= 1
        ), "stream_scale_factor should be at least 1, change it according to your actual rtf"
        self.llm_context = (
            paddle.device.stream_guard(
                paddle.device.Stream(device=device2str(self.device))
            )
            if paddle.device.cuda.device_count() >= 1
            else nullcontext()
        )
        self.lock = threading.Lock()
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.mel_overlap_dict = {}
        self.flow_cache_dict = {}
        self.hift_cache_dict = {}

    def load(self, llm_model, flow_model, hift_model):
        self.llm.set_state_dict(state_dict=paddle.load(path=str(llm_model)))
        self.llm.to(self.device).eval()
        self.flow.set_state_dict(state_dict=paddle.load(path=str(flow_model)))
        self.flow.to(self.device).eval()
        hift_state_dict = {
            k.replace("generator.", ""): v
            for k, v in paddle.load(path=str(hift_model)).items()
        }
        self.hift.set_state_dict(state_dict=hift_state_dict)
        self.hift.to(self.device).eval()

    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
        # paddle.jit.load expects modules saved with paddle.jit.save and,
        # unlike torch.jit.load, takes no map_location argument.
        llm_text_encoder = paddle.jit.load(llm_text_encoder_model)
        self.llm.text_encoder = llm_text_encoder
        llm_llm = paddle.jit.load(llm_llm_model)
        self.llm.llm = llm_llm
        flow_encoder = paddle.jit.load(flow_encoder_model)
        self.flow.encoder = flow_encoder

    def load_trt(
        self,
        flow_decoder_estimator_model,
        flow_decoder_onnx_model,
        trt_concurrent,
        fp16,
    ):
        assert paddle.device.cuda.device_count() >= 1, "tensorrt only supports gpu!"
        if (
            not os.path.exists(flow_decoder_estimator_model)
            or os.path.getsize(flow_decoder_estimator_model) == 0
        ):
            convert_onnx_to_trt(
                flow_decoder_estimator_model,
                self.get_trt_kwargs(),
                flow_decoder_onnx_model,
                fp16,
            )
        del self.flow.decoder.estimator
        import tensorrt as trt

        with open(flow_decoder_estimator_model, "rb") as f:
            estimator_engine = trt.Runtime(
                trt.Logger(trt.Logger.INFO)
            ).deserialize_cuda_engine(f.read())
        assert estimator_engine is not None, "failed to load trt {}".format(
            flow_decoder_estimator_model
        )
        self.flow.decoder.estimator = TrtContextWrapper(
            estimator_engine, trt_concurrent=trt_concurrent, device=self.device
        )

    def get_trt_kwargs(self):
        min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
        opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
        max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
        input_names = ["x", "mask", "mu", "cond"]
        return {
            "min_shape": min_shape,
            "opt_shape": opt_shape,
            "max_shape": max_shape,
            "input_names": input_names,
        }

    def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
        with self.llm_context, paddle.amp.auto_cast(
            enable=self.fp16 is True and hasattr(self.llm, "vllm") is False
        ):
            if isinstance(text, Generator):
                assert isinstance(self, CosyVoice2Model) and not hasattr(
                    self.llm, "vllm"
                ), "streaming input text is only implemented for CosyVoice2 and does not support vllm!"
                for i in self.llm.inference_bistream(
                    text=text,
                    prompt_text=prompt_text.to(self.device),
                    prompt_text_len=paddle.to_tensor(
                        [prompt_text.shape[1]], dtype=paddle.int32
                    ).to(self.device),
                    prompt_speech_token=llm_prompt_speech_token.to(self.device),
                    prompt_speech_token_len=paddle.to_tensor(
                        [llm_prompt_speech_token.shape[1]], dtype=paddle.int32
                    ).to(self.device),
                    embedding=llm_embedding.to(self.device),
                ):
                    self.tts_speech_token_dict[uuid].append(i)
            else:
                for i in self.llm.inference(
                    text=text.to(self.device),
                    text_len=paddle.to_tensor(
                        [text.shape[1]], dtype=paddle.int32
                    ).to(self.device),
                    prompt_text=prompt_text.to(self.device),
                    prompt_text_len=paddle.to_tensor(
                        [prompt_text.shape[1]], dtype=paddle.int32
                    ).to(self.device),
                    prompt_speech_token=llm_prompt_speech_token.to(self.device),
                    prompt_speech_token_len=paddle.to_tensor(
                        [llm_prompt_speech_token.shape[1]], dtype=paddle.int32
                    ).to(self.device),
                    embedding=llm_embedding.to(self.device),
                    uuid=uuid,
                ):
                    self.tts_speech_token_dict[uuid].append(i)
        self.llm_end_dict[uuid] = True

    def vc_job(self, source_speech_token, uuid):
        self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
        self.llm_end_dict[uuid] = True

    def token2wav(
        self,
        token,
        prompt_token,
        prompt_feat,
        embedding,
        uuid,
        finalize=False,
        speed=1.0,
    ):
        with paddle.amp.auto_cast(enable=self.fp16):
            tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(
                token=token.to(self.device),
                token_len=paddle.to_tensor(
                    [token.shape[1]], dtype=paddle.int32
                ).to(self.device),
                prompt_token=prompt_token.to(self.device),
                prompt_token_len=paddle.to_tensor(
                    [prompt_token.shape[1]], dtype=paddle.int32
                ).to(self.device),
                prompt_feat=prompt_feat.to(self.device),
                prompt_feat_len=paddle.to_tensor(
                    [prompt_feat.shape[1]], dtype=paddle.int32
                ).to(self.device),
                embedding=embedding.to(self.device),
                flow_cache=self.flow_cache_dict[uuid],
            )
        if self.mel_overlap_dict[uuid].shape[2] != 0:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = (
                self.hift_cache_dict[uuid]["mel"],
                self.hift_cache_dict[uuid]["source"],
            )
            tts_mel = paddle.concat([hift_cache_mel, tts_mel], axis=2)
        else:
            hift_cache_source = paddle.zeros([1, 1, 0])
        if finalize is False:
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len :]
            tts_mel = tts_mel[:, :, : -self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(
                speech_feat=tts_mel, cache_source=hift_cache_source
            )
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(
                    tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window
                )
            self.hift_cache_dict[uuid] = {
                "mel": tts_mel[:, :, -self.mel_cache_len :],
                "source": tts_source[:, :, -self.source_cache_len :],
                "speech": tts_speech[:, -self.source_cache_len :],
            }
            tts_speech = tts_speech[:, : -self.source_cache_len]
        else:
            if speed != 1.0:
                assert (
                    self.hift_cache_dict[uuid] is None
                ), "speed change only supports non-stream inference mode"
                tts_mel = paddle.nn.functional.interpolate(
                    x=tts_mel, size=[int(tts_mel.shape[2] / speed)], mode="linear"
                )
            tts_speech, tts_source = self.hift.inference(
                speech_feat=tts_mel, cache_source=hift_cache_source
            )
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(
                    tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window
                )
        return tts_speech

    def tts(
        self,
        text=paddle.zeros([1, 0], dtype=paddle.int32),
        flow_embedding=paddle.zeros([0, 192]),
        llm_embedding=paddle.zeros([0, 192]),
        prompt_text=paddle.zeros([1, 0], dtype=paddle.int32),
        llm_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        flow_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        prompt_speech_feat=paddle.zeros([1, 0, 80]),
        source_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        stream=False,
        speed=1.0,
        **kwargs
    ):
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = (
                [],
                False,
            )
            self.hift_cache_dict[this_uuid] = None
            self.mel_overlap_dict[this_uuid] = paddle.zeros([1, 80, 0])
            self.flow_cache_dict[this_uuid] = paddle.zeros([1, 80, 0, 2])
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(
                target=self.llm_job,
                args=(
                    text,
                    prompt_text,
                    llm_prompt_speech_token,
                    llm_embedding,
                    this_uuid,
                ),
            )
        else:
            p = threading.Thread(
                target=self.vc_job, args=(source_speech_token, this_uuid)
            )
        p.start()
        if stream is True:
            token_hop_len = self.token_min_hop_len
            while True:
                time.sleep(0.1)
                if (
                    len(self.tts_speech_token_dict[this_uuid])
                    >= token_hop_len + self.token_overlap_len
                ):
                    this_tts_speech_token = paddle.to_tensor(
                        self.tts_speech_token_dict[this_uuid][
                            : token_hop_len + self.token_overlap_len
                        ]
                    ).unsqueeze(axis=0)
                    this_tts_speech = self.token2wav(
                        token=this_tts_speech_token,
                        prompt_token=flow_prompt_speech_token,
                        prompt_feat=prompt_speech_feat,
                        embedding=flow_embedding,
                        uuid=this_uuid,
                        finalize=False,
                    )
                    yield {"tts_speech": this_tts_speech.cpu()}
                    with self.lock:
                        self.tts_speech_token_dict[
                            this_uuid
                        ] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                    token_hop_len = min(
                        self.token_max_hop_len,
                        int(token_hop_len * self.stream_scale_factor),
                    )
                if (
                    self.llm_end_dict[this_uuid] is True
                    and len(self.tts_speech_token_dict[this_uuid])
                    < token_hop_len + self.token_overlap_len
                ):
                    break
            p.join()
            this_tts_speech_token = paddle.to_tensor(
                self.tts_speech_token_dict[this_uuid]
            ).unsqueeze(axis=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True,
            )
            yield {"tts_speech": this_tts_speech.cpu()}
        else:
            p.join()
            this_tts_speech_token = paddle.to_tensor(
                self.tts_speech_token_dict[this_uuid]
            ).unsqueeze(axis=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                uuid=this_uuid,
                finalize=True,
                speed=speed,
            )
            yield {"tts_speech": this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.mel_overlap_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
            self.flow_cache_dict.pop(this_uuid)
        if paddle.device.cuda.device_count() >= 1:
            paddle.device.cuda.empty_cache()
            paddle.device.current_stream().synchronize()


class CosyVoice2Model(CosyVoiceModel):
    def __init__(
        self,
        llm: paddle.nn.Layer,
        flow: paddle.nn.Layer,
        hift: paddle.nn.Layer,
        fp16: bool = False,
    ):
        self.device = device2str(
            "cuda" if paddle.device.cuda.device_count() >= 1 else "cpu"
        )
        self.llm = llm
        self.flow = flow
        self.hift = hift
        self.fp16 = fp16
        if self.fp16 is True:
            self.llm.half()
            self.flow.half()
        self.token_hop_len = 25
        self.mel_cache_len = 8
        self.source_cache_len = int(self.mel_cache_len * 480)
        self.speech_window = np.hamming(2 * self.source_cache_len)
        self.llm_context = (
            paddle.device.stream_guard(
                paddle.device.Stream(device=device2str(self.device))
            )
            if paddle.device.cuda.device_count() >= 1
            else nullcontext()
        )
        self.lock = threading.Lock()
        self.tts_speech_token_dict = {}
        self.llm_end_dict = {}
        self.hift_cache_dict = {}

    def load_jit(self, flow_encoder_model):
        # paddle.jit.load expects a paddle.jit.save'd module (no map_location).
        flow_encoder = paddle.jit.load(flow_encoder_model)
        self.flow.encoder = flow_encoder

    def load_vllm(self, model_dir):
        export_cosyvoice2_vllm(self.llm, model_dir, self.device)
        from vllm import EngineArgs, LLMEngine

        engine_args = EngineArgs(
            model=model_dir,
            skip_tokenizer_init=True,
            enable_prompt_embeds=True,
            gpu_memory_utilization=0.2,
        )
        self.llm.vllm = LLMEngine.from_engine_args(engine_args)
        self.llm.lock = threading.Lock()
        del self.llm.llm.model.model.layers

    def token2wav(
        self,
        token,
        prompt_token,
        prompt_feat,
        embedding,
        token_offset,
        uuid,
        stream=False,
        finalize=False,
        speed=1.0,
    ):
        with paddle.amp.auto_cast(enable=self.fp16):
            tts_mel, _ = self.flow.inference(
                token=token.to(self.device),
                token_len=paddle.to_tensor(
                    [token.shape[1]], dtype=paddle.int32
                ).to(self.device),
                prompt_token=prompt_token.to(self.device),
                prompt_token_len=paddle.to_tensor(
                    [prompt_token.shape[1]], dtype=paddle.int32
                ).to(self.device),
                prompt_feat=prompt_feat.to(self.device),
                prompt_feat_len=paddle.to_tensor(
                    [prompt_feat.shape[1]], dtype=paddle.int32
                ).to(self.device),
                embedding=embedding.to(self.device),
                streaming=stream,
                finalize=finalize,
            )
        tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio :]
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel, hift_cache_source = (
                self.hift_cache_dict[uuid]["mel"],
                self.hift_cache_dict[uuid]["source"],
            )
            tts_mel = paddle.concat([hift_cache_mel, tts_mel], axis=2)
        else:
            hift_cache_source = paddle.zeros([1, 1, 0])
        if finalize is False:
            tts_speech, tts_source = self.hift.inference(
                speech_feat=tts_mel, cache_source=hift_cache_source
            )
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(
                    tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window
                )
            self.hift_cache_dict[uuid] = {
                "mel": tts_mel[:, :, -self.mel_cache_len :],
                "source": tts_source[:, :, -self.source_cache_len :],
                "speech": tts_speech[:, -self.source_cache_len :],
            }
            tts_speech = tts_speech[:, : -self.source_cache_len]
        else:
            if speed != 1.0:
                assert (
                    self.hift_cache_dict[uuid] is None
                ), "speed change only supports non-stream inference mode"
                tts_mel = paddle.nn.functional.interpolate(
                    x=tts_mel, size=[int(tts_mel.shape[2] / speed)], mode="linear"
                )
            tts_speech, tts_source = self.hift.inference(
                speech_feat=tts_mel, cache_source=hift_cache_source
            )
            if self.hift_cache_dict[uuid] is not None:
                tts_speech = fade_in_out(
                    tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window
                )
        return tts_speech

    def tts(
        self,
        text=paddle.zeros([1, 0], dtype=paddle.int32),
        flow_embedding=paddle.zeros([0, 192]),
        llm_embedding=paddle.zeros([0, 192]),
        prompt_text=paddle.zeros([1, 0], dtype=paddle.int32),
        llm_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        flow_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        prompt_speech_feat=paddle.zeros([1, 0, 80]),
        source_speech_token=paddle.zeros([1, 0], dtype=paddle.int32),
        stream=False,
        speed=1.0,
        **kwargs
    ):
        this_uuid = str(uuid.uuid1())
        with self.lock:
            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = (
                [],
                False,
            )
            self.hift_cache_dict[this_uuid] = None
        if source_speech_token.shape[1] == 0:
            p = threading.Thread(
                target=self.llm_job,
                args=(
                    text,
                    prompt_text,
                    llm_prompt_speech_token,
                    llm_embedding,
                    this_uuid,
                ),
            )
        else:
            p = threading.Thread(
                target=self.vc_job, args=(source_speech_token, this_uuid)
            )
        p.start()
        if stream is True:
            token_offset = 0
            prompt_token_pad = int(
                np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len)
                * self.token_hop_len
                - flow_prompt_speech_token.shape[1]
            )
            while True:
                time.sleep(0.1)
                this_token_hop_len = (
                    self.token_hop_len + prompt_token_pad
                    if token_offset == 0
                    else self.token_hop_len
                )
                if (
                    len(self.tts_speech_token_dict[this_uuid]) - token_offset
                    >= this_token_hop_len + self.flow.pre_lookahead_len
                ):
                    this_tts_speech_token = paddle.to_tensor(
                        self.tts_speech_token_dict[this_uuid][
                            : token_offset
                            + this_token_hop_len
                            + self.flow.pre_lookahead_len
                        ]
                    ).unsqueeze(axis=0)
                    this_tts_speech = self.token2wav(
                        token=this_tts_speech_token,
                        prompt_token=flow_prompt_speech_token,
                        prompt_feat=prompt_speech_feat,
                        embedding=flow_embedding,
                        token_offset=token_offset,
                        uuid=this_uuid,
                        stream=stream,
                        finalize=False,
                    )
                    token_offset += this_token_hop_len
                    yield {"tts_speech": this_tts_speech.cpu()}
                if (
                    self.llm_end_dict[this_uuid] is True
                    and len(self.tts_speech_token_dict[this_uuid]) - token_offset
                    < this_token_hop_len + self.flow.pre_lookahead_len
                ):
                    break
            p.join()
            this_tts_speech_token = paddle.to_tensor(
                self.tts_speech_token_dict[this_uuid]
            ).unsqueeze(axis=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                token_offset=token_offset,
                uuid=this_uuid,
                finalize=True,
            )
            yield {"tts_speech": this_tts_speech.cpu()}
        else:
            p.join()
            this_tts_speech_token = paddle.to_tensor(
                self.tts_speech_token_dict[this_uuid]
            ).unsqueeze(axis=0)
            this_tts_speech = self.token2wav(
                token=this_tts_speech_token,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                token_offset=0,
                uuid=this_uuid,
                finalize=True,
                speed=speed,
            )
            yield {"tts_speech": this_tts_speech.cpu()}
        with self.lock:
            self.tts_speech_token_dict.pop(this_uuid)
            self.llm_end_dict.pop(this_uuid)
            self.hift_cache_dict.pop(this_uuid)
        if paddle.device.cuda.device_count() >= 1:
            paddle.device.cuda.empty_cache()
            paddle.device.current_stream().synchronize()
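

# NOTE: `fade_in_out` is called by `token2wav` above but is not part of this
# diff. The helper below is a hedged sketch of what it is assumed to do,
# modeled on the CosyVoice utility of the same name: cross-fade the head of
# the new chunk with the cached tail of the previous chunk using the two
# halves of a symmetric window (e.g. np.hamming(2 * overlap_len)).
def fade_in_out(fade_in_mel, fade_out_mel, window):
    mel_overlap_len = int(window.shape[0] / 2)
    win = paddle.to_tensor(window, dtype=fade_in_mel.dtype)
    fade_in_mel = fade_in_mel.clone()
    # rising half weights the incoming chunk, falling half the cached tail
    fade_in_mel[..., :mel_overlap_len] = (
        fade_in_mel[..., :mel_overlap_len] * win[:mel_overlap_len]
        + fade_out_mel[..., -mel_overlap_len:] * win[mel_overlap_len:]
    )
    return fade_in_mel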
@ -0,0 +1,607 @@
"""HIFI-GAN"""
from typing import Dict, List, Optional

import numpy as np
import paddle
from scipy.signal import get_window

from paddlespeech.t2s.modules.transformer.activation import Snake
from paddlespeech.t2s.models.CosyVoice.common import get_padding, init_weights

"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan,
https://github.com/kan-bayashi/ParallelWaveGAN and
https://github.com/NVIDIA/BigVGAN
"""


class ResBlock(paddle.nn.Layer):
    """Residual block module in HiFiGAN/BigVGAN."""

    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],
    ):
        super(ResBlock, self).__init__()
        self.convs1 = paddle.nn.LayerList()
        self.convs2 = paddle.nn.LayerList()
        for dilation in dilations:
            self.convs1.append(
                paddle.nn.Conv1D(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation,
                    padding=get_padding(kernel_size, dilation),
                )
            )
            self.convs2.append(
                paddle.nn.Conv1D(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = paddle.nn.LayerList(
            sublayers=[
                Snake(channels, alpha_logscale=False) for _ in range(len(self.convs1))
            ]
        )
        self.activations2 = paddle.nn.LayerList(
            sublayers=[
                Snake(channels, alpha_logscale=False) for _ in range(len(self.convs2))
            ]
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            paddle.nn.utils.remove_weight_norm(layer=self.convs1[idx])
            paddle.nn.utils.remove_weight_norm(layer=self.convs2[idx])


class SineGen(paddle.nn.Layer):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        uv = (f0 > self.voiced_threshold).astype(paddle.float32)
        return uv

    @paddle.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """
        F_mat = paddle.zeros([f0.shape[0], self.harmonic_num + 1, f0.shape[-1]]).to(
            f0.place
        )
        for i in range(self.harmonic_num + 1):
            F_mat[:, i : i + 1, :] = f0 * (i + 1) / self.sampling_rate
        theta_mat = 2 * np.pi * (paddle.cumsum(F_mat, axis=-1) % 1)
        u_dist = paddle.distribution.Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(shape=(f0.shape[0], self.harmonic_num + 1, 1)).to(
            F_mat.place
        )
        phase_vec[:, 0, :] = 0
        sine_waves = self.sine_amp * paddle.sin(theta_mat + phase_vec)
        uv = self._f02uv(f0)
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * paddle.randn(shape=sine_waves.shape, dtype=sine_waves.dtype)
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
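

# Illustrative shape check (assumes the [B, 1, sample_len] layout documented
# above): drive the generator with one second of a constant 100 Hz contour.
def _demo_sine_gen(sampling_rate: int = 22050):
    gen = SineGen(samp_rate=sampling_rate, harmonic_num=8)
    f0 = paddle.full([1, 1, sampling_rate], 100.0)   # [B, 1, sample_len]
    sine_waves, uv, noise = gen(f0)
    # one band per harmonic (fundamental + 8 overtones); every frame voiced
    print(sine_waves.shape, float(uv.mean()))        # [1, 9, 22050] 1.0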


class SourceModuleHnNSF(paddle.nn.Layer):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced frames is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(
        self,
        sampling_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        sinegen_type="1",
        causal=False,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        if sinegen_type == "1":
            self.l_sin_gen = SineGen(
                sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
            )
        else:
            self.l_sin_gen = SineGen2(
                sampling_rate,
                upsample_scale,
                harmonic_num,
                sine_amp,
                add_noise_std,
                voiced_threshod,
                causal=causal,
            )
        self.l_linear = paddle.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = paddle.nn.Tanh()
        self.causal = causal
        paddle.seed(1986)

        if causal is True:
            # pre-generated noise buffer keeps causal/streaming inference deterministic
            self.uv = paddle.rand(shape=[1, 300 * 24000, 1])
            self.register_buffer("uv_buffer", self.uv)

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        with paddle.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)

        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        if not self.training and self.causal:
            noise = self.uv_buffer[:, : uv.shape[1]] * self.sine_amp / 3
        else:
            noise = paddle.randn(shape=uv.shape, dtype=uv.dtype) * self.sine_amp / 3
        return sine_merge, noise, uv


class SineGen2(paddle.nn.Layer):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num=0,
            sine_amp=0.1, noise_std=0.003,
            voiced_threshold=0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
    segment is always sin(np.pi) or cos(0)
    """

    def __init__(
        self,
        samp_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
        causal=False,
    ):
        super(SineGen2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale
        self.causal = causal
        paddle.seed(1986)
        if causal is True:
            # pre-generated random state keeps causal/streaming inference deterministic
            self.rand_ini = paddle.rand(shape=[1, 9])
            self.rand_ini[:, 0] = 0
            self.sine_waves = paddle.rand(shape=[1, 300 * 24000, 9])
            self.register_buffer("rand_ini_buffer", self.rand_ini)
            self.register_buffer("sine_waves_buffer", self.sine_waves)

    def _f02uv(self, f0):
        uv = (f0 > self.voiced_threshold).astype("float32")
        return uv

    def _f02sine(self, f0_values):
        """f0_values: (batchsize, length, dim)"""
        rad_values = (f0_values / self.sampling_rate) % 1
        if not self.training and self.causal:
            rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini_buffer
        else:
            rand_ini = paddle.rand(shape=[f0_values.shape[0], f0_values.shape[2]])
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        if not self.flag_for_pulse:
            scale_factor_down = 1.0 / self.upsample_scale

            rad_values = paddle.nn.functional.interpolate(
                rad_values.transpose([0, 2, 1]),
                scale_factor=scale_factor_down,
                mode="linear",
            ).transpose([0, 2, 1])
            phase = paddle.cumsum(rad_values, axis=1) * 2 * np.pi

            interpolate_mode = "nearest" if self.causal else "linear"
            phase = paddle.transpose(
                paddle.nn.functional.interpolate(
                    paddle.transpose(phase, perm=[0, 2, 1]) * self.upsample_scale,
                    scale_factor=float(self.upsample_scale),
                    mode=interpolate_mode,
                ),
                perm=[0, 2, 1],
            )

            sines = paddle.sin(phase)
        else:
            uv = self._f02uv(f0_values)

            uv_1 = paddle.roll(uv, shifts=-1, axis=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            tmp_cumsum = paddle.cumsum(rad_values, axis=1)
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            i_phase = paddle.cumsum(rad_values - tmp_cumsum, axis=1)
            sines = paddle.cos(i_phase * 2 * np.pi)

        return sines
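
        # Why the cumulative sum above yields a phase: for harmonic k with
        # instantaneous frequency f_k[n], the discrete-time phase is
        #     phi_k[n] = 2 * pi * sum_{m <= n} (f_k[m] / fs),
        # so `rad_values` holds per-step phase increments and `cumsum`
        # integrates them (the per-step `% 1` only drops whole cycles, which
        # leaves sin(phase) unchanged). The interpolate/cumsum/interpolate
        # sandwich integrates at the cheaper downsampled rate, then multiplies
        # by `upsample_scale` when expanding back so the increment magnitudes
        # stay consistent at sample rate.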

    def forward(self, f0):
        """sine_tensor, uv = forward(f0)"""
        paddle.seed(1986)
        harmonic_coeffs = paddle.to_tensor(
            [list(range(1, self.harmonic_num + 2))], dtype="float32"
        ).reshape([1, 1, -1])

        fn = f0 * harmonic_coeffs

        sine_waves = self._f02sine(fn) * self.sine_amp
        uv = self._f02uv(f0)

        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3

        if not self.training and self.causal:
            noise = noise_amp * self.sine_waves_buffer[:, : sine_waves.shape[1]]
        else:
            noise = noise_amp * paddle.randn(
                shape=sine_waves.shape, dtype=sine_waves.dtype
            )
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF2(paddle.nn.Layer):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that the amplitude of noise in unvoiced frames is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length, 1)
    uv (batchsize, length, 1)
    """

    def __init__(
        self,
        sampling_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
    ):
        super(SourceModuleHnNSF2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen2(
            sampling_rate,
            upsample_scale,
            harmonic_num,
            sine_amp,
            add_noise_std,
            voiced_threshod,
        )
        self.l_linear = paddle.nn.Linear(in_features=harmonic_num + 1, out_features=1)
        self.l_tanh = paddle.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length, 1)
        """
        paddle.seed(1986)
        with paddle.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        noise = paddle.randn(shape=uv.shape, dtype=uv.dtype) * self.sine_amp / 3
        return sine_merge, noise, uv


class HiFTGenerator(paddle.nn.Layer):
    """
    HiFTNet Generator: Neural Source Filter + ISTFTNet
    https://arxiv.org/abs/2309.09493
    """

    def __init__(
        self,
        in_channels: int = 80,
        base_channels: int = 512,
        nb_harmonics: int = 8,
        sampling_rate: int = 22050,
        nsf_alpha: float = 0.1,
        nsf_sigma: float = 0.003,
        nsf_voiced_threshold: float = 10,
        upsample_rates: List[int] = [8, 8],
        upsample_kernel_sizes: List[int] = [16, 16],
        istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        source_resblock_kernel_sizes: List[int] = [7, 11],
        source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
        lrelu_slope: float = 0.1,
        audio_limit: float = 0.99,
        f0_predictor: paddle.nn.Layer = None,
    ):
        super(HiFTGenerator, self).__init__()
        self.out_channels = 1
        self.nb_harmonics = nb_harmonics
        self.sampling_rate = sampling_rate
        self.istft_params = istft_params
        self.lrelu_slope = lrelu_slope
        self.audio_limit = audio_limit
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold,
            sinegen_type="1" if self.sampling_rate == 22050 else "2",
            causal=False,
        )
        self.f0_upsamp = paddle.nn.Upsample(
            scale_factor=(1, int(np.prod(upsample_rates) * istft_params["hop_len"]))
        )
        self.conv_pre = paddle.nn.Conv1D(
            in_channels=in_channels,
            out_channels=base_channels,
            kernel_size=7,
            stride=1,
            padding=3,
        )
        self.ups = paddle.nn.LayerList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                paddle.nn.Conv1DTranspose(
                    in_channels=base_channels // 2 ** i,
                    out_channels=base_channels // 2 ** (i + 1),
                    kernel_size=k,
                    stride=u,
                    padding=(k - u) // 2,
                )
            )
        self.source_downs = paddle.nn.LayerList()
        self.source_resblocks = paddle.nn.LayerList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(
            zip(
                downsample_cum_rates[::-1],
                source_resblock_kernel_sizes,
                source_resblock_dilation_sizes,
            )
        ):
            if u == 1:
                self.source_downs.append(
                    paddle.nn.Conv1D(
                        istft_params["n_fft"] + 2, base_channels // 2 ** (i + 1), 1, 1
                    )
                )
            else:
                self.source_downs.append(
                    paddle.nn.Conv1D(
                        in_channels=istft_params["n_fft"] + 2,
                        out_channels=base_channels // (2 ** (i + 1)),
                        kernel_size=(u * 2,),
                        stride=(u,),
                        padding=int(u // 2),
                    )
                )
            self.source_resblocks.append(ResBlock(base_channels // 2 ** (i + 1), k, d))
        self.resblocks = paddle.nn.LayerList()
        for i in range(len(self.ups)):
            ch = base_channels // 2 ** (i + 1)
            for _, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(ResBlock(ch, k, d))
        self.conv_post = paddle.nn.Conv1D(
            ch, istft_params["n_fft"] + 2, 7, 1, padding=3
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = paddle.nn.Pad1D(padding=(1, 0), mode="reflect")
        self.stft_window = paddle.to_tensor(
            get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)
        )
        self.f0_predictor = f0_predictor

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            paddle.nn.utils.remove_weight_norm(layer=l)
        for l in self.resblocks:
            l.remove_weight_norm()
        paddle.nn.utils.remove_weight_norm(layer=self.conv_pre)
        paddle.nn.utils.remove_weight_norm(layer=self.conv_post)
        self.m_source.remove_weight_norm()
        for l in self.source_downs:
            paddle.nn.utils.remove_weight_norm(layer=l)
        for l in self.source_resblocks:
            l.remove_weight_norm()

    def _stft(self, x):
        spec = paddle.signal.stft(
            x=x,
            n_fft=self.istft_params["n_fft"],
            hop_length=self.istft_params["hop_len"],
            win_length=self.istft_params["n_fft"],
            window=self.stft_window.to(x.place),
        )
        spec = paddle.as_real(spec)
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = paddle.clip(magnitude, max=100.0)
        real = magnitude * paddle.cos(phase)
        img = magnitude * paddle.sin(phase)
        inverse_transform = paddle.signal.istft(
            x=paddle.complex(real, img),
            n_fft=self.istft_params["n_fft"],
            hop_length=self.istft_params["hop_len"],
            win_length=self.istft_params["n_fft"],
            window=self.stft_window.to(magnitude.place),
        )
        return inverse_transform

    def decode(
        self, x: paddle.Tensor, s: paddle.Tensor = paddle.zeros([1, 1, 0])
    ) -> paddle.Tensor:
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = paddle.concat([s_stft_real, s_stft_imag], axis=1)
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = paddle.nn.functional.leaky_relu(x=x, negative_slope=self.lrelu_slope)
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = paddle.nn.functional.leaky_relu(x=x)
        x = self.conv_post(x)
        # first n_fft // 2 + 1 channels carry log-magnitude, the rest carry phase
        magnitude = paddle.exp(x=x[:, : self.istft_params["n_fft"] // 2 + 1, :])
        phase = paddle.sin(x[:, self.istft_params["n_fft"] // 2 + 1 :, :])
        x = self._istft(magnitude, phase)
        x = paddle.clip(x, -self.audio_limit, self.audio_limit)
        return x

    def forward(self, batch: dict):
        # NOTE: the original port called `.to(device)` here, but no `device`
        # is defined in this scope; the transpose alone is kept.
        speech_feat = paddle.transpose(batch["speech_feat"], perm=[0, 2, 1])
        f0 = self.f0_predictor(speech_feat)
        s = paddle.transpose(self.f0_upsamp(f0[:, None]), perm=[0, 2, 1])
        s, _, _ = self.m_source(s)
        s = paddle.transpose(s, perm=[0, 2, 1])

        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, f0

    @paddle.no_grad()
    def inference(
        self,
        speech_feat: paddle.Tensor,
        cache_source: paddle.Tensor = paddle.zeros([1, 1, 0]),
    ):
        paddle.seed(1986)
        f0 = self.f0_predictor(speech_feat)
        f0_4d = f0[:, None].unsqueeze(2)
        s_4d = self.f0_upsamp(f0_4d)
        s_3d = s_4d.squeeze(2)
        s = paddle.transpose(s_3d, perm=[0, 2, 1])

        s, _, _ = self.m_source(s)
        s = paddle.transpose(s, perm=[0, 2, 1])

        if cache_source.shape[2] != 0:
            s[:, :, : cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)

        return generated_speech, s
@ -0,0 +1,27 @@
import paddle


class ConvRNNF0Predictor(paddle.nn.Layer):
    def __init__(
        self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512
    ):
        super().__init__()
        self.num_class = num_class
        self.condnet = paddle.nn.Sequential(
            paddle.nn.Conv1D(in_channels, cond_channels, kernel_size=3, padding=1),
            paddle.nn.ELU(),
            paddle.nn.Conv1D(cond_channels, cond_channels, kernel_size=3, padding=1),
            paddle.nn.ELU(),
            paddle.nn.Conv1D(cond_channels, cond_channels, kernel_size=3, padding=1),
            paddle.nn.ELU(),
            paddle.nn.Conv1D(cond_channels, cond_channels, kernel_size=3, padding=1),
            paddle.nn.ELU(),
            paddle.nn.Conv1D(cond_channels, cond_channels, kernel_size=3, padding=1),
            paddle.nn.ELU(),
        )
        self.classifier = paddle.nn.Linear(
            in_features=cond_channels, out_features=self.num_class
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        for layer in self.condnet:
            x = layer(x)
        x = paddle.transpose(x, perm=[0, 2, 1])
        return paddle.abs(x=self.classifier(x).squeeze(-1))
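

# Minimal shape-check sketch (illustrative only): an 80-band mel segment of
# 100 frames maps to one non-negative F0 value per frame.
def _demo_f0_predictor():
    predictor = ConvRNNF0Predictor()
    mel = paddle.randn([1, 80, 100])   # [B, in_channels, T]
    f0 = predictor(mel)
    print(f0.shape)                    # [1, 100]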
@ -0,0 +1,9 @@
import paddle


class Transpose(paddle.nn.Layer):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # paddle.transpose takes a full permutation rather than a dim pair,
        # so build one with dim0 and dim1 swapped.
        perm = list(range(x.ndim))
        perm[self.dim0], perm[self.dim1] = perm[self.dim1], perm[self.dim0]
        return paddle.transpose(x, perm=perm)
@ -0,0 +1 @@
from .flow import CausalMaskedDiffWithXvec
@ -0,0 +1,386 @@
import math
from typing import Any, Dict, Optional

import paddle

from .attention_processor import Attention


def _chunked_feed_forward(
    ff: paddle.nn.Layer, hidden_states: paddle.Tensor, chunk_dim: int, chunk_size: int
):
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} "
            f"has to be divisible by chunk size: {chunk_size}. Make sure to set an "
            f"appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = paddle.concat(
        [
            ff(hid_slice)
            for hid_slice in hidden_states.chunk(num_chunks, axis=chunk_dim)
        ],
        axis=chunk_dim,
    )
    return ff_output
|
||||
|
||||
class SinusoidalPositionalEmbedding(paddle.nn.Layer):
|
||||
"""Apply positional information to a sequence of embeddings.
|
||||
|
||||
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
|
||||
them
|
||||
|
||||
Args:
|
||||
embed_dim: (int): Dimension of the positional embedding.
|
||||
max_seq_length: Maximum sequence length to apply positional embeddings
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
||||
super().__init__()
|
||||
position = paddle.arange(max_seq_length).unsqueeze(1)
|
||||
div_term = paddle.exp(
|
||||
x=paddle.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
|
||||
)
|
||||
pe = paddle.zeros(1, max_seq_length, embed_dim)
|
||||
pe[0, :, 0::2] = paddle.sin(position * div_term)
|
||||
pe[0, :, 1::2] = paddle.cos(position * div_term)
|
||||
self.register_buffer(name="pe", tensor=pe)
|
||||
|
||||
def forward(self, x):
|
||||
_, seq_length, _ = x.shape
|
||||
x = x + self.pe[:, :seq_length]
|
||||
return x
|
||||
|
||||
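# Hedged sketch (not part of the original file): a quick shape check for the
# embedding above; embed_dim/max_seq_length values are illustrative.
if __name__ == "__main__":
    pos_emb = SinusoidalPositionalEmbedding(embed_dim=8, max_seq_length=32)
    tokens = paddle.zeros([1, 5, 8])   # (batch, seq_len, embed_dim)
    out = pos_emb(tokens)              # row t holds the classic sin/cos table at position t
    print(out.shape)                   # [1, 5, 8]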
class GatedSelfAttentionDense(paddle.nn.Layer):
    """
    A gated self-attention dense layer that combines visual features and object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()
        self.linear = paddle.nn.Linear(in_features=context_dim, out_features=query_dim)
        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")
        self.norm1 = paddle.nn.LayerNorm(normalized_shape=query_dim)
        self.norm2 = paddle.nn.LayerNorm(normalized_shape=query_dim)
        # learnable scalar gates, initialised to 0 so tanh(alpha) = 0 and the
        # block starts as an identity mapping
        # (the original `paddle.nn.parameter.Parameter(paddle.tensor(0.0))` is
        # not a valid paddle API; paddle.create_parameter is used instead)
        self.add_parameter(
            name="alpha_attn",
            parameter=paddle.create_parameter(
                shape=[1], dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0.0),
            ),
        )
        self.add_parameter(
            name="alpha_dense",
            parameter=paddle.create_parameter(
                shape=[1], dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0.0),
            ),
        )
        self.enabled = True

    def forward(self, x: paddle.Tensor, objs: paddle.Tensor) -> paddle.Tensor:
        if not self.enabled:
            return x
        n_visual = x.shape[1]
        objs = self.linear(objs)
        x = (
            x
            + self.alpha_attn.tanh()
            * self.attn(self.norm1(paddle.concat([x, objs], axis=1)))[:, :n_visual, :]
        )
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
        return x
class BasicTransformerBlock(paddle.nn.Layer):
    """
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",
        norm_eps: float = 1e-05,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm_zero = (
            num_embeds_ada_norm is not None and norm_type == "ada_norm_zero"
        )
        self.use_ada_layer_norm = (
            num_embeds_ada_norm is not None and norm_type == "ada_norm"
        )
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )
        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm
        if positional_embeddings and num_positional_embeddings is None:
            raise ValueError(
                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
            )
        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(
                dim, max_seq_length=num_positional_embeddings
            )
        else:
            self.pos_embed = None

        self.norm1 = paddle.nn.LayerNorm(
            normalized_shape=dim,
            weight_attr=norm_elementwise_affine,
            bias_attr=norm_elementwise_affine,
            epsilon=norm_eps,
        )
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )
        if cross_attention_dim is not None or double_self_attention:
            if norm_type == "ada_norm":
                self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
            elif norm_type == "ada_norm_continuous":
                self.norm2 = AdaLayerNormContinuous(
                    dim,
                    ada_norm_continous_conditioning_embedding_dim,
                    norm_elementwise_affine,
                    norm_eps,
                    ada_norm_bias,
                    "rms_norm",
                )
            else:
                self.norm2 = paddle.nn.LayerNorm(
                    normalized_shape=dim,
                    epsilon=norm_eps,
                    weight_attr=norm_elementwise_affine,
                    bias_attr=norm_elementwise_affine,
                )
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim
                if not double_self_attention
                else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )
        else:
            self.norm2 = None
            self.attn2 = None
        if norm_type == "ada_norm_continuous":
            self.norm3 = AdaLayerNormContinuous(
                dim,
                ada_norm_continous_conditioning_embedding_dim,
                norm_elementwise_affine,
                norm_eps,
                ada_norm_bias,
                "layer_norm",
            )
        elif norm_type in [
            "ada_norm_zero",
            "ada_norm",
            "layer_norm",
            "ada_norm_continuous",
        ]:
            self.norm3 = paddle.nn.LayerNorm(
                normalized_shape=dim,
                epsilon=norm_eps,
                weight_attr=norm_elementwise_affine,
                bias_attr=norm_elementwise_affine,
            )
        elif norm_type == "layer_norm_i2vgen":
            self.norm3 = None
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )
        if attention_type == "gated" or attention_type == "gated-text-image":
            self.fuser = GatedSelfAttentionDense(
                dim, cross_attention_dim, num_attention_heads, attention_head_dim
            )
        if norm_type == "ada_norm_single":
            # per-block scale/shift table, initialised ~ N(0, 1/sqrt(dim));
            # replaces the invalid `paddle.nn.parameter.Parameter(paddle.randn(6, dim) / dim**0.5)`
            self.scale_shift_table = paddle.create_parameter(
                shape=[6, dim],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Normal(std=dim**-0.5),
            )
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: paddle.Tensor,
        attention_mask: Optional[paddle.Tensor] = None,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        encoder_attention_mask: Optional[paddle.Tensor] = None,
        timestep: Optional[paddle.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[paddle.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, paddle.Tensor]] = None,
    ) -> paddle.Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
                )
        batch_size = hidden_states.shape[0]
        if self.norm_type == "ada_norm":
            norm_hidden_states = self.norm1(hidden_states, timestep)
        elif self.norm_type == "ada_norm_zero":
            (norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.norm1(
                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
            )
        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
            norm_hidden_states = self.norm1(hidden_states)
        elif self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm1(
                hidden_states, added_cond_kwargs["pooled_text_emb"]
            )
        elif self.norm_type == "ada_norm_single":
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.scale_shift_table[None] + timestep.reshape([batch_size, 6, -1])
            ).chunk(6, axis=1)
            norm_hidden_states = self.norm1(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            norm_hidden_states = norm_hidden_states.squeeze(1)
        else:
            raise ValueError("Incorrect norm used")
        if self.pos_embed is not None:
            norm_hidden_states = self.pos_embed(norm_hidden_states)
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )
        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states
            if self.only_cross_attention
            else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        if self.norm_type == "ada_norm_zero":
            attn_output = gate_msa.unsqueeze(1) * attn_output
        elif self.norm_type == "ada_norm_single":
            attn_output = gate_msa * attn_output
        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        if gligen_kwargs is not None:
            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
        if self.attn2 is not None:
            if self.norm_type == "ada_norm":
                norm_hidden_states = self.norm2(hidden_states, timestep)
            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
                norm_hidden_states = self.norm2(hidden_states)
            elif self.norm_type == "ada_norm_single":
                norm_hidden_states = hidden_states
            elif self.norm_type == "ada_norm_continuous":
                norm_hidden_states = self.norm2(
                    hidden_states, added_cond_kwargs["pooled_text_emb"]
                )
            else:
                raise ValueError("Incorrect norm")
            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                norm_hidden_states = self.pos_embed(norm_hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states
        if self.norm_type == "ada_norm_continuous":
            norm_hidden_states = self.norm3(
                hidden_states, added_cond_kwargs["pooled_text_emb"]
            )
        elif not self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm3(hidden_states)
        if self.norm_type == "ada_norm_zero":
            norm_hidden_states = (
                norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            )
        if self.norm_type == "ada_norm_single":
            norm_hidden_states = self.norm2(hidden_states)
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if self.norm_type == "ada_norm_zero":
            ff_output = gate_mlp.unsqueeze(1) * ff_output
        elif self.norm_type == "ada_norm_single":
            ff_output = gate_mlp * ff_output
        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)
        return hidden_states
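# Hedged sketch (not part of the original file): _chunked_feed_forward splits the
# chunked axis before applying the feed-forward layer; with a plain Linear `ff`
# the result equals the unchunked application.
if __name__ == "__main__":
    ff = paddle.nn.Linear(16, 16)
    h = paddle.randn([2, 12, 16])                 # 12 is divisible by chunk_size=4
    chunked = _chunked_feed_forward(ff, h, chunk_dim=1, chunk_size=4)
    full = ff(h)
    print(float((chunked - full).abs().max()))    # ~0.0 up to float rounding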
@ -0,0 +1,625 @@
import inspect
import logging
import math
from typing import Callable, List, Optional, Union

import paddle

logger = logging.getLogger(__name__)


class Attention(paddle.nn.Layer):
    """
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-05,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        context_pre_only=None,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = (
            cross_attention_dim if cross_attention_dim is not None else query_dim
        )
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self._from_deprecated_attn_block = _from_deprecated_attn_block
        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.sliceable_head_dim = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention
        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )
        if norm_num_groups is not None:
            self.group_norm = paddle.nn.GroupNorm(
                num_channels=query_dim,
                num_groups=norm_num_groups,
                epsilon=eps,
                weight_attr=True,
                bias_attr=True,
            )
        else:
            self.group_norm = None
        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(
                f_channels=query_dim, zq_channels=spatial_norm_dim
            )
        else:
            self.spatial_norm = None
        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "layer_norm":
            self.norm_q = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps)
            self.norm_k = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps)
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'"
            )
        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = paddle.nn.LayerNorm(
                normalized_shape=self.cross_attention_dim
            )
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim
            self.norm_cross = paddle.nn.GroupNorm(
                num_channels=norm_cross_num_channels,
                num_groups=cross_attention_norm_num_groups,
                epsilon=1e-05,
                weight_attr=True,
                bias_attr=True,
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )
        self.to_q = paddle.nn.Linear(
            in_features=query_dim, out_features=self.inner_dim, bias_attr=bias
        )
        if not self.only_cross_attention:
            self.to_k = paddle.nn.Linear(
                in_features=self.cross_attention_dim,
                out_features=self.inner_dim,
                bias_attr=bias,
            )
            self.to_v = paddle.nn.Linear(
                in_features=self.cross_attention_dim,
                out_features=self.inner_dim,
                bias_attr=bias,
            )
        else:
            self.to_k = None
            self.to_v = None
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = paddle.nn.Linear(
                in_features=added_kv_proj_dim, out_features=self.inner_dim
            )
            self.add_v_proj = paddle.nn.Linear(
                in_features=added_kv_proj_dim, out_features=self.inner_dim
            )
            if self.context_pre_only is not None:
                self.add_q_proj = paddle.nn.Linear(
                    in_features=added_kv_proj_dim, out_features=self.inner_dim
                )
        self.to_out = paddle.nn.LayerList(sublayers=[])
        self.to_out.append(
            paddle.nn.Linear(
                in_features=self.inner_dim,
                out_features=self.out_dim,
                bias_attr=out_bias,
            )
        )
        self.to_out.append(paddle.nn.Dropout(p=dropout))
        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = paddle.nn.Linear(
                in_features=self.inner_dim,
                out_features=self.out_dim,
                bias_attr=out_bias,
            )
        # the original unconditionally overwrote the processor twice, discarding
        # the `processor` argument; honour it and default to AttnProcessor2_0
        if processor is None:
            processor = AttnProcessor2_0()
        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor") -> None:
        """
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, paddle.nn.Layer)
            and not isinstance(processor, paddle.nn.Layer)
        ):
            # paddle.nn.Layer keeps sublayers in `_sub_layers` (not `_modules`)
            self._sub_layers.pop("processor")
        self.processor = processor

    def forward(
        self,
        hidden_states: paddle.Tensor,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        **cross_attention_kwargs,
    ) -> paddle.Tensor:
        """
        The forward method of the `Attention` class.

        Args:
            hidden_states (`paddle.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`paddle.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`paddle.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `paddle.Tensor`: The output of the attention layer.
        """
        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k
            for k, _ in cross_attention_kwargs.items()
            if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: paddle.Tensor) -> paddle.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`paddle.Tensor`): The tensor to reshape.

        Returns:
            `paddle.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape([batch_size // head_size, head_size, seq_len, dim])
        tensor = tensor.transpose([0, 2, 1, 3]).reshape(
            [batch_size // head_size, seq_len, dim * head_size]
        )
        return tensor

    def head_to_batch_dim(
        self, tensor: paddle.Tensor, out_dim: int = 3
    ) -> paddle.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`paddle.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `paddle.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(
            [batch_size, seq_len * extra_dim, head_size, dim // head_size]
        )
        tensor = tensor.transpose([0, 2, 1, 3])
        if out_dim == 3:
            tensor = tensor.reshape(
                [batch_size * head_size, seq_len * extra_dim, dim // head_size]
            )
        return tensor

    def get_attention_scores(
        self,
        query: paddle.Tensor,
        key: paddle.Tensor,
        attention_mask: paddle.Tensor = None,
    ) -> paddle.Tensor:
        """
        Compute the attention scores.

        Args:
            query (`paddle.Tensor`): The query tensor.
            key (`paddle.Tensor`): The key tensor.
            attention_mask (`paddle.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `paddle.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.astype("float32")
            key = key.astype("float32")
        # explicit bmm replaces the original torch.baddbmm call; the additive
        # attention-mask bias covers the beta/input bookkeeping
        attention_scores = self.scale * paddle.bmm(query, key.transpose([0, 2, 1]))
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        if self.upcast_softmax:
            attention_scores = attention_scores.astype("float32")
        attention_probs = paddle.nn.functional.softmax(attention_scores, axis=-1)
        del attention_scores
        attention_probs = attention_probs.astype(dtype)
        return attention_probs

    def prepare_attention_mask(
        self,
        attention_mask: paddle.Tensor,
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> paddle.Tensor:
        """
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`paddle.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `paddle.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask
        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            # pad the last dimension with zeros up to `target_length`; this merges
            # the original code's separate `mps` branch, which paddle does not need
            padding_shape = attention_mask.shape[:-1] + [target_length - current_length]
            padding = paddle.zeros(padding_shape, dtype=attention_mask.dtype)
            attention_mask = paddle.concat([attention_mask, padding], axis=-1)
        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, axis=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, axis=1)
        return attention_mask

    def norm_encoder_hidden_states(
        self, encoder_hidden_states: paddle.Tensor
    ) -> paddle.Tensor:
        """
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`paddle.Tensor`): Hidden states of the encoder.

        Returns:
            `paddle.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        if isinstance(self.norm_cross, paddle.nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, paddle.nn.GroupNorm):
            encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose([0, 2, 1])
        else:
            assert False
        return encoder_hidden_states

    @paddle.no_grad()
    def fuse_projections(self, fuse=True):
        device = self.to_q.weight.place
        dtype = self.to_q.weight.dtype
        # NOTE: paddle.nn.Linear stores weights as [in_features, out_features]
        # (transposed relative to torch), so fused projections concatenate along
        # the output axis and read in/out features from shape[0]/shape[1].
        if not self.is_cross_attention:
            concatenated_weights = paddle.concat(
                [self.to_q.weight, self.to_k.weight, self.to_v.weight], axis=1
            )
            in_features = concatenated_weights.shape[0]
            out_features = concatenated_weights.shape[1]
            self.to_qkv = paddle.nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias_attr=self.use_bias,
            )
            self.to_qkv.weight.set_value(concatenated_weights)
            if self.use_bias:
                concatenated_bias = paddle.concat(
                    [self.to_q.bias, self.to_k.bias, self.to_v.bias]
                )
                self.to_qkv.bias.set_value(concatenated_bias)
        else:
            concatenated_weights = paddle.concat(
                [self.to_k.weight, self.to_v.weight], axis=1
            )
            in_features = concatenated_weights.shape[0]
            out_features = concatenated_weights.shape[1]
            self.to_kv = paddle.nn.Linear(
                in_features=in_features,
                out_features=out_features,
                bias_attr=self.use_bias,
            )
            self.to_kv.weight.set_value(concatenated_weights)
            if self.use_bias:
                concatenated_bias = paddle.concat(
                    [self.to_k.bias, self.to_v.bias]
                )
                self.to_kv.bias.set_value(concatenated_bias)
        self.fused_projections = fuse
class AttnProcessor:
    """
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: paddle.Tensor,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        temb: Optional[paddle.Tensor] = None,
        *args,
        **kwargs,
    ) -> paddle.Tensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.reshape(
                [batch_size, channel, height * width]
            ).transpose([0, 2, 1])
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose([0, 2, 1])
            ).transpose([0, 2, 1])
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = paddle.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose([0, 2, 1]).reshape(
                [batch_size, channel, height, width]
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


class AttnProcessor2_0:
    """
    Processor for implementing scaled dot-product attention via the fused
    paddle.nn.functional.scaled_dot_product_attention kernel.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: paddle.Tensor,
        encoder_hidden_states: Optional[paddle.Tensor] = None,
        attention_mask: Optional[paddle.Tensor] = None,
        temb: Optional[paddle.Tensor] = None,
        *args,
        **kwargs,
    ) -> paddle.Tensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.reshape(
                [batch_size, channel, height * width]
            ).transpose([0, 2, 1])
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            attention_mask = attention_mask.reshape(
                [batch_size, attn.heads, -1, attention_mask.shape[-1]]
            )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose([0, 2, 1])
            ).transpose([0, 2, 1])
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        # paddle's scaled_dot_product_attention expects (batch, seq_len, num_heads,
        # head_dim), so no torch-style head transpose is needed; the original
        # 3-element perms on 4-D tensors were invalid
        query = query.reshape([batch_size, -1, attn.heads, head_dim])
        key = key.reshape([batch_size, -1, attn.heads, head_dim])
        value = value.reshape([batch_size, -1, attn.heads, head_dim])
        hidden_states = paddle.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
        hidden_states = hidden_states.reshape(
            [batch_size, -1, attn.heads * head_dim]
        )
        hidden_states = hidden_states.astype(query.dtype)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose([0, 2, 1]).reshape(
                [batch_size, channel, height, width]
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
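# Hedged usage sketch (not part of the original file): self-attention when
# `encoder_hidden_states` is omitted, cross-attention when it is given. The
# explicit AttnProcessor path is selected here because the fused
# scaled_dot_product_attention kernel may be unavailable on CPU-only builds.
if __name__ == "__main__":
    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(AttnProcessor())              # explicit bmm + softmax path
    x = paddle.randn([2, 10, 64])                    # (batch, seq_len, query_dim)
    y_self = attn(x)                                 # keys/values from x itself
    ctx = paddle.randn([2, 7, 64])
    y_cross = attn(x, encoder_hidden_states=ctx)     # keys/values from ctx
    print(y_self.shape, y_cross.shape)               # [2, 10, 64] [2, 10, 64]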
File diff suppressed because it is too large
@ -0,0 +1,99 @@
"""ConvolutionModule definition."""
from typing import Tuple

import paddle


class ConvolutionModule(paddle.nn.Layer):
    """ConvolutionModule in Conformer model."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 15,
        activation: paddle.nn.Layer = paddle.nn.ReLU(),
        norm: str = "batch_norm",
        causal: bool = False,
        bias: bool = True,
    ):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (bool): Whether to use causal convolution or not.
        """
        super().__init__()
        self.pointwise_conv1 = paddle.nn.Conv1D(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias_attr=bias
        )
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = paddle.nn.Conv1D(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias_attr=bias,
        )
        assert norm in ["batch_norm", "layer_norm"]
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = paddle.nn.BatchNorm1D(num_features=channels)
        else:
            self.use_layer_norm = True
            self.norm = paddle.nn.LayerNorm(normalized_shape=channels)
        self.pointwise_conv2 = paddle.nn.Conv1D(
            channels, channels, kernel_size=1, stride=1, padding=0, bias_attr=bias
        )
        self.activation = activation

    def forward(
        self,
        x: paddle.Tensor,
        mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool),
        cache: paddle.Tensor = paddle.zeros((0, 0, 0)),
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute convolution module.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, channels).
            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (paddle.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            paddle.Tensor: Output tensor (#batch, time, channels).
        """
        x = x.transpose([0, 2, 1])  # (#batch, channels, time)
        if mask_pad.shape[2] > 0:
            x = x.masked_fill(~mask_pad, 0.0)
        if self.lorder > 0:
            if cache.shape[2] == 0:
                x = paddle.nn.functional.pad(
                    x, [self.lorder, 0], mode="constant", value=0.0, data_format="NCL"
                )
            else:
                assert cache.shape[0] == x.shape[0]
                assert cache.shape[1] == x.shape[1]
                x = paddle.concat([cache, x], axis=2)
            assert x.shape[2] > self.lorder
            new_cache = x[:, :, -self.lorder:]
        else:
            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
        x = self.pointwise_conv1(x)
        x = paddle.nn.functional.glu(x=x, axis=1)
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])
        x = self.pointwise_conv2(x)
        if mask_pad.shape[2] > 0:
            x = x.masked_fill(~mask_pad, 0.0)
        return x.transpose([0, 2, 1]), new_cache
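# Hedged usage sketch (not part of the original file): the module consumes and
# returns (batch, time, channels); the cache output only matters when causal=True,
# where it carries the last kernel_size - 1 frames into the next chunk.
if __name__ == "__main__":
    conv = ConvolutionModule(channels=8, kernel_size=15, causal=True)
    x = paddle.randn([2, 50, 8])                              # (batch, time, channels)
    y, cache = conv(x)                                        # y: [2, 50, 8]; cache: [2, 8, 14]
    y2, cache = conv(paddle.randn([2, 50, 8]), cache=cache)   # streaming continuation
    print(y.shape, cache.shape)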
@ -0,0 +1,762 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple, Any, Dict, Optional

import paddle
from paddle import nn
import paddle.nn.functional as F
from einops import pack, rearrange, repeat

# NOTE: mask_to_bias and add_optional_chunk_mask are re-defined locally below;
# the local definitions shadow these imports.
from paddlespeech.t2s.models.CosyVoice.common import mask_to_bias
from paddlespeech.t2s.models.CosyVoice.mask import add_optional_chunk_mask
from .matcha_transformer import BasicTransformerBlock


def get_activation(act_fn):
    if act_fn == "silu":
        return nn.Silu()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "relu":
        return nn.ReLU()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


class Block1D(nn.Layer):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1D(dim, dim_out, 3, padding=1),
            nn.GroupNorm(groups, dim_out),
            nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(nn.Layer):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(time_emb_dim, dim_out)
        )

        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv1D(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        # add the time embedding, broadcast over the time axis
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(nn.Layer):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1D(dim, dim, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(nn.Layer):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias_attr=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None and self.cond_proj is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Upsample1D(nn.Layer):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.Conv1DTranspose(channels, self.out_channels, 4, stride=2, padding=1)
        elif use_conv:
            self.conv = nn.Conv1D(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)

        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")

        if self.use_conv:
            outputs = self.conv(outputs)

        return outputs


class Transpose(nn.Layer):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # build a full permutation; the original hard-coded [0, dim1, dim0],
        # which only worked for swapping axes 1 and 2 of a 3-D tensor
        perm = list(range(x.ndim))
        perm[self.dim0], perm[self.dim1] = perm[self.dim1], perm[self.dim0]
        return paddle.transpose(x, perm)


class CausalConv1d(nn.Conv1D):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        padding_mode: str = 'zeros'
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels,
                                           kernel_size, stride,
                                           padding=0, dilation=dilation,
                                           groups=groups,
                                           padding_mode=padding_mode)
        assert stride == 1
        self.causal_padding = kernel_size - 1

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # left-pad only, so the convolution never sees future frames
        x = F.pad(x, [self.causal_padding, 0], value=0.0, data_format="NCL")
        x = super(CausalConv1d, self).forward(x)
        return x


class SinusoidalPosEmb(paddle.nn.Layer):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim).astype('float32') * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
        return emb


class CausalBlock1D(Block1D):
    def __init__(self, dim: int, dim_out: int):
        super(CausalBlock1D, self).__init__(dim, dim_out)
        self.block = nn.Sequential(
            CausalConv1d(dim, dim_out, 3),
            Transpose(1, 2),
            nn.LayerNorm(dim_out),
            Transpose(1, 2),
            nn.Mish()
        )

    def forward(self, x: paddle.Tensor, mask: paddle.Tensor) -> paddle.Tensor:
        output = self.block(x * mask)
        return output * mask


class CausalResnetBlock1D(ResnetBlock1D):
    def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
        super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
        self.block1 = CausalBlock1D(dim, dim_out)
        self.block2 = CausalBlock1D(dim_out, dim_out)


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
) -> paddle.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks

    Returns:
        paddle.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE: num_left_chunks is currently ignored by this implementation,
    # i.e. full left context is always used.
    pos_idx = paddle.arange(size, dtype='int64')
    block_value = (pos_idx // chunk_size + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)

    return ret
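# Hedged sketch (not part of the original file): a worked example of the chunk
# mask above; row i may attend to every position whose chunk index <= its own.
if __name__ == "__main__":
    m = subsequent_chunk_mask(4, 2)
    print(m.astype("int32"))
    # [[1, 1, 0, 0],
    #  [1, 1, 0, 0],
    #  [1, 1, 1, 1],
    #  [1, 1, 1, 1]]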
def add_optional_chunk_mask(xs: paddle.Tensor,
                            masks: paddle.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """Apply optional mask for encoder.

    Args:
        xs (paddle.Tensor): padded input, (B, L, D), L for max length
        masks (paddle.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding;
            if it's greater than 0 and use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context (max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        paddle.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.shape[1]
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context (max_len).
            # Since we use 4 times subsampling and allow up to 1s (100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = paddle.randint(1, max_len, shape=(1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = paddle.randint(0, max_left_chunks, shape=(1,)).item()

        chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
                                            num_left_chunks)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
                                            num_left_chunks)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == paddle.bool
    if (chunk_masks.sum(axis=-1) == 0).sum().item() != 0:
        print('get chunk_masks all false at some timestep, force set to true, '
              'make sure they are masked in future computation!')
        all_false_mask = chunk_masks.sum(axis=-1) == 0
        chunk_masks = paddle.where(all_false_mask.unsqueeze(-1),
                                   paddle.ones_like(chunk_masks, dtype='bool'),
                                   chunk_masks)

    return chunk_masks


def mask_to_bias(mask: paddle.Tensor, dtype: paddle.dtype) -> paddle.Tensor:
    assert mask.dtype == paddle.bool, "Input mask must be of boolean type"
    assert dtype in [paddle.float32, paddle.bfloat16, paddle.float16], f"Unsupported dtype: {dtype}"
    # keep-positions become 0, masked positions a large negative additive bias
    mask = mask.astype(dtype)
    mask = (1.0 - mask) * -1.0e+10

    return mask
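# Hedged sketch (not part of the original file): mask_to_bias turns a boolean
# keep-mask into an additive attention bias (0 where kept, -1e10 where masked).
if __name__ == "__main__":
    keep = paddle.to_tensor([[True, True, False]])
    bias = mask_to_bias(keep, paddle.float32)
    print(bias)    # [[0., 0., -1e10]] up to float rounding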
class ConditionalDecoder(nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.down_blocks = nn.LayerList([])
        self.mid_blocks = nn.LayerList([])
        self.up_blocks = nn.LayerList([])

        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(nn.LayerList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.LayerList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(nn.LayerList([resnet, transformer_blocks, upsample]))
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1D(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.sublayers():
            if isinstance(m, nn.Conv1D):
                # Paddle initializers are callables applied to a parameter.
                nn.initializer.KaimingNormal(nonlinearity='relu')(m.weight)
                if m.bias is not None:
                    nn.initializer.Constant(value=0.0)(m.bias)
            elif isinstance(m, nn.GroupNorm):
                nn.initializer.Constant(value=1.0)(m.weight)
                nn.initializer.Constant(value=0.0)(m.bias)
            elif isinstance(m, nn.Linear):
                nn.initializer.KaimingNormal(nonlinearity='relu')(m.weight)
                if m.bias is not None:
                    nn.initializer.Constant(value=0.0)(m.bias)

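# Note on the initializer calls above: Paddle initializers are objects that
# are constructed first and then applied to a parameter, unlike torch.nn.init
# functions that mutate a tensor in place. A standalone sketch:
#
#     >>> conv = nn.Conv1D(8, 8, 3, padding=1)
#     >>> nn.initializer.KaimingNormal(nonlinearity='relu')(conv.weight)
#     >>> nn.initializer.Constant(value=0.0)(conv.bias)
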
    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (paddle.Tensor): shape (batch_size, in_channels, time)
            mask (paddle.Tensor): shape (batch_size, 1, time)
            mu (paddle.Tensor): mean tensor for conditioning
            t (paddle.Tensor): shape (batch_size)
            spks (paddle.Tensor, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (paddle.Tensor, optional): placeholder for future use. Defaults to None.

        Returns:
            paddle.Tensor: output tensor
        """

        t = self.time_embeddings(t).astype(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            # .tile replaces torch's Tensor.repeat under Paddle.
            attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask


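# The pack([x, mu], "b * t") calls above concatenate conditioning signals
# along the channel axis (einops.pack). A shape check with illustrative sizes:
#
#     >>> import numpy as np
#     >>> from einops import pack
#     >>> x = np.zeros((2, 80, 100))     # (batch, mel channels, time)
#     >>> mu = np.zeros((2, 80, 100))    # encoder output, same layout
#     >>> spks = np.zeros((2, 64, 100))  # speaker embedding tiled over time
#     >>> pack([x, mu, spks], "b * t")[0].shape
#     (2, 224, 100)
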
class CausalConditionalDecoder(nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        static_chunk_size=50,
        num_decoding_left_chunks=2,
    ):
        """
        This decoder requires an input with the same shape as the target. So, if your text content
        is shorter or longer than the outputs, please resample it before feeding it to the decoder.
        """
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels,
            time_embed_dim=time_embed_dim,
            act_fn="silu",
        )
        self.static_chunk_size = static_chunk_size
        self.num_decoding_left_chunks = num_decoding_left_chunks
        self.down_blocks = nn.LayerList([])
        self.mid_blocks = nn.LayerList([])
        self.up_blocks = nn.LayerList([])

        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)  # assumed already implemented
            )
            self.down_blocks.append(nn.LayerList([resnet, transformer_blocks, downsample]))

        for _ in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)

            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.LayerList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.LayerList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)  # assumed already implemented
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.LayerList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])  # assumed already implemented
        self.final_proj = nn.Conv1D(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def initialize_weights(self):
        for m in self.sublayers():  # sublayers() is Paddle's counterpart of torch's modules()
            if isinstance(m, nn.Conv1D):
                nn.initializer.KaimingNormal(nonlinearity='relu')(m.weight)
                if m.bias is not None:
                    nn.initializer.Constant(value=0.0)(m.bias)
            elif isinstance(m, nn.GroupNorm):
                nn.initializer.Constant(value=1.0)(m.weight)
                nn.initializer.Constant(value=0.0)(m.bias)
            elif isinstance(m, nn.Linear):
                nn.initializer.KaimingNormal(nonlinearity='relu')(m.weight)
                if m.bias is not None:
                    nn.initializer.Constant(value=0.0)(m.bias)

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (paddle.Tensor): shape (batch_size, in_channels, time)
            mask (paddle.Tensor): shape (batch_size, 1, time)
            mu (paddle.Tensor): mean tensor for conditioning
            t (paddle.Tensor): shape (batch_size)
            spks (paddle.Tensor, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (paddle.Tensor, optional): placeholder for future use. Defaults to None.
            streaming (bool, optional): whether to use streaming mode. Defaults to False.

        Returns:
            paddle.Tensor: output tensor
        """
        t = self.time_embeddings(t).astype(t.dtype)  # astype instead of torch's .to(t.dtype)
        t = self.time_mlp(t)
        x = pack([x, mu], "b * t")[0]  # pack/repeat/rearrange come from einops
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)

            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                # Chunk-causal attention with static_chunk_size frames per chunk.
                attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])  # .shape instead of torch's .size()
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )

            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, 0, -1).tile([1, x.shape[1], 1])
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask

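# In streaming mode the attention mask above is chunk-causal: every frame may
# attend to all frames in its own chunk and in earlier chunks. The pattern for
# six frames with chunk_size=2 (mirroring subsequent_chunk_mask from mask.py
# later in this diff):
#
#     >>> pos_idx = paddle.arange(6)
#     >>> block_value = (pos_idx // 2 + 1) * 2
#     >>> (pos_idx.unsqueeze(0) < block_value.unsqueeze(1)).astype('int32').numpy()
#     array([[1, 1, 0, 0, 0, 0],
#            [1, 1, 0, 0, 0, 0],
#            [1, 1, 1, 1, 0, 0],
#            [1, 1, 1, 1, 0, 0],
#            [1, 1, 1, 1, 1, 1],
#            [1, 1, 1, 1, 1, 1]], dtype=int32)
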
@ -0,0 +1,132 @@
import paddle

from ..utils import deprecate
from ..utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu

ACTIVATION_FUNCTIONS = {
    "swish": paddle.nn.SiLU(),
    "silu": paddle.nn.SiLU(),
    "mish": paddle.nn.Mish(),
    "gelu": paddle.nn.GELU(),
    "relu": paddle.nn.ReLU(),
}


def get_activation(act_fn: str) -> paddle.nn.Layer:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        paddle.nn.Layer: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACTIVATION_FUNCTIONS:
        return ACTIVATION_FUNCTIONS[act_fn]
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


class FP32SiLU(paddle.nn.Layer):
    """
    SiLU activation function with input upcast to float32.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
        # Paddle's silu has no `inplace` argument; upcast, activate, cast back.
        return paddle.nn.functional.silu(inputs.astype(paddle.float32)).astype(inputs.dtype)


class GELU(paddle.nn.Layer):
    """
    GELU activation function with tanh approximation support with `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True
    ):
        super().__init__()
        self.proj = paddle.nn.Linear(
            in_features=dim_in, out_features=dim_out, bias_attr=bias
        )
        self.approximate = approximate

    def gelu(self, gate: paddle.Tensor) -> paddle.Tensor:
        # Paddle's gelu takes a boolean flag instead of torch's string argument;
        # torch's "mps" float32 workaround does not apply here.
        return paddle.nn.functional.gelu(gate, approximate=self.approximate == "tanh")

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


class GEGLU(paddle.nn.Layer):
    """
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = paddle.nn.Linear(
            in_features=dim_in, out_features=dim_out * 2, bias_attr=bias
        )

    def gelu(self, gate: paddle.Tensor) -> paddle.Tensor:
        return paddle.nn.functional.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = self.proj(hidden_states)
        if is_torch_npu_available():
            return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
        else:
            hidden_states, gate = hidden_states.chunk(2, axis=-1)
            return hidden_states * self.gelu(gate)


class ApproximateGELU(paddle.nn.Layer):
    """
    The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
    [paper](https://arxiv.org/abs/1606.08415).

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = paddle.nn.Linear(
            in_features=dim_in, out_features=dim_out, bias_attr=bias
        )

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        x = self.proj(x)
        return x * paddle.nn.functional.sigmoid(1.702 * x)

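# A quick shape check for the gated projection above (sizes illustrative):
# GEGLU projects to 2 * dim_out channels, then splits them into a value half
# and a GELU-gated half, so the output width is dim_out again.
#
#     >>> geglu = GEGLU(dim_in=16, dim_out=32)
#     >>> geglu(paddle.randn([2, 10, 16])).shape
#     [2, 10, 32]
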
@ -0,0 +1,326 @@
import logging
import random
from typing import Dict, Optional

import paddle
from omegaconf import DictConfig

from paddlespeech.t2s.models.CosyVoice.mask import make_pad_mask


class MaskedDiffWithXvec(paddle.nn.Layer):
    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 80,
        spk_embed_dim: int = 192,
        output_type: str = "mel",
        vocab_size: int = 4096,
        input_frame_rate: int = 50,
        only_mask_loss: bool = True,
        encoder: paddle.nn.Layer = None,
        length_regulator: paddle.nn.Layer = None,
        decoder: paddle.nn.Layer = None,
        decoder_conf: Dict = {
            "in_channels": 240,
            "out_channel": 80,
            "spk_emb_dim": 80,
            "n_spks": 1,
            "cfm_params": DictConfig(
                {
                    "sigma_min": 1e-06,
                    "solver": "euler",
                    "t_scheduler": "cosine",
                    "training_cfg_rate": 0.2,
                    "inference_cfg_rate": 0.7,
                    "reg_loss_type": "l1",
                }
            ),
            "decoder_params": {
                "channels": [256, 256],
                "dropout": 0.0,
                "attention_head_dim": 64,
                "n_blocks": 4,
                "num_mid_blocks": 12,
                "num_heads": 8,
                "act_fn": "gelu",
            },
        },
        mel_feat_conf: Dict = {
            "n_fft": 1024,
            "num_mels": 80,
            "sampling_rate": 22050,
            "hop_size": 256,
            "win_size": 1024,
            "fmin": 0,
            "fmax": 8000,
        },
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = paddle.nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = paddle.nn.Linear(
            in_features=spk_embed_dim, out_features=output_size
        )
        self.encoder = encoder
        self.encoder_proj = paddle.nn.Linear(
            in_features=self.encoder.output_size(), out_features=output_size
        )
        self.decoder = decoder
        self.length_regulator = length_regulator
        self.only_mask_loss = only_mask_loss

    def forward(
        self, batch: dict, device: paddle.device
    ) -> Dict[str, Optional[paddle.Tensor]]:
        token = batch["speech_token"].to(device)
        token_len = batch["speech_token_len"].to(device)
        feat = batch["speech_feat"].to(device)
        feat_len = batch["speech_feat_len"].to(device)
        embedding = batch["embedding"].to(device)
        embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
        embedding = self.spk_embed_affine_layer(embedding)
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(paddle.clip(token, min=0)) * mask
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        h, h_lengths = self.length_regulator(h, feat_len)
        # paddle.zeros takes no device argument; tensors land on the current place.
        conds = paddle.zeros(feat.shape, dtype=feat.dtype)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = paddle.transpose(conds, perm=[0, 2, 1])
        mask = (~make_pad_mask(feat_len)).to(h)
        loss, _ = self.decoder.compute_loss(
            paddle.transpose(feat, perm=[0, 2, 1]),
            mask.unsqueeze(1),
            paddle.transpose(h, perm=[0, 2, 1]),
            embedding,
            cond=conds,
        )
        return {"loss": loss}

    @paddle.no_grad()
    def inference(
        self,
        token,
        token_len,
        prompt_token,
        prompt_token_len,
        prompt_feat,
        prompt_feat_len,
        embedding,
        flow_cache,
    ):
        assert token.shape[0] == 1
        embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
        embedding = self.spk_embed_affine_layer(embedding)
        token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
        token, token_len = (
            paddle.concat([prompt_token, token], axis=1),  # paddle.concat, not torch.cat
            prompt_token_len + token_len,
        )
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(paddle.clip(token, min=0)) * mask
        h, h_lengths = self.encoder(token, token_len)
        h = self.encoder_proj(h)
        mel_len1, mel_len2 = prompt_feat.shape[1], int(
            token_len2 / self.input_frame_rate * 22050 / 256
        )
        h, h_lengths = self.length_regulator.inference(
            h[:, :token_len1],
            h[:, token_len1:],
            mel_len1,
            mel_len2,
            self.input_frame_rate,
        )
        conds = paddle.zeros(
            [1, mel_len1 + mel_len2, self.output_size], dtype=h.dtype
        )
        conds[:, :mel_len1] = prompt_feat
        conds = paddle.transpose(conds, perm=[0, 2, 1])

        mask = (~make_pad_mask(paddle.to_tensor([mel_len1 + mel_len2]))).to(h)
        feat, flow_cache = self.decoder(
            mu=paddle.transpose(h, perm=[0, 2, 1]),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            prompt_len=mel_len1,
            cache=flow_cache,
        )
        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), flow_cache

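# The mel_len2 computation above converts a token count into a mel-frame
# count: tokens arrive at input_frame_rate (50 Hz here) while mel frames
# arrive at sampling_rate / hop_size = 22050 / 256 ≈ 86.1 Hz. Worked example:
#
#     >>> token_len2 = 100                  # 2 s of speech tokens at 50 Hz
#     >>> int(token_len2 / 50 * 22050 / 256)
#     172
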

class CausalMaskedDiffWithXvec(paddle.nn.Layer):
    def __init__(
        self,
        input_size: int = 512,
        output_size: int = 80,
        spk_embed_dim: int = 192,
        output_type: str = "mel",
        vocab_size: int = 6561,
        input_frame_rate: int = 25,
        only_mask_loss: bool = True,
        token_mel_ratio: int = 2,
        pre_lookahead_len: int = 3,
        encoder: paddle.nn.Layer = None,
        decoder: paddle.nn.Layer = None,
        decoder_conf: Dict = {
            "in_channels": 240,
            "out_channel": 80,
            "spk_emb_dim": 80,
            "n_spks": 1,
            "cfm_params": DictConfig(
                {
                    "sigma_min": 1e-06,
                    "solver": "euler",
                    "t_scheduler": "cosine",
                    "training_cfg_rate": 0.2,
                    "inference_cfg_rate": 0.7,
                    "reg_loss_type": "l1",
                }
            ),
            "decoder_params": {
                "channels": [256, 256],
                "dropout": 0.0,
                "attention_head_dim": 64,
                "n_blocks": 4,
                "num_mid_blocks": 12,
                "num_heads": 8,
                "act_fn": "gelu",
            },
        },
        mel_feat_conf: Dict = {
            "n_fft": 1024,
            "num_mels": 80,
            "sampling_rate": 22050,
            "hop_size": 256,
            "win_size": 1024,
            "fmin": 0,
            "fmax": 8000,
        },
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.decoder_conf = decoder_conf
        self.mel_feat_conf = mel_feat_conf
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.input_frame_rate = input_frame_rate
        logging.info(f"input frame rate={self.input_frame_rate}")
        self.input_embedding = paddle.nn.Embedding(vocab_size, input_size)
        self.spk_embed_affine_layer = paddle.nn.Linear(
            in_features=spk_embed_dim, out_features=output_size
        )
        self.encoder = encoder
        self.encoder_proj = paddle.nn.Linear(
            in_features=self.encoder.output_size(), out_features=output_size
        )
        self.decoder = decoder
        self.only_mask_loss = only_mask_loss
        self.token_mel_ratio = token_mel_ratio
        self.pre_lookahead_len = pre_lookahead_len

    def forward(
        self, batch: dict, device: paddle.device
    ) -> Dict[str, Optional[paddle.Tensor]]:
        token = batch["speech_token"].to(device)
        token_len = batch["speech_token_len"].to(device)
        feat = batch["speech_feat"].to(device)
        feat_len = batch["speech_feat_len"].to(device)
        embedding = batch["embedding"].to(device)
        # Train in streaming or offline mode with equal probability.
        streaming = random.random() < 0.5
        embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
        embedding = self.spk_embed_affine_layer(embedding)
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
        token = self.input_embedding(paddle.clip(token, min=0)) * mask
        h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        h = self.encoder_proj(h)
        conds = paddle.zeros(feat.shape, dtype=feat.dtype)
        for i, j in enumerate(feat_len):
            if random.random() < 0.5:
                continue
            index = random.randint(0, int(0.3 * j))
            conds[i, :index] = feat[i, :index]
        conds = paddle.transpose(conds, perm=[0, 2, 1])

        mask = (~make_pad_mask(h_lengths.sum(axis=-1).squeeze(axis=1))).to(h)
        loss, _ = self.decoder.compute_loss(
            paddle.transpose(feat, perm=[0, 2, 1]).contiguous(),
            mask.unsqueeze(1),
            paddle.transpose(h, perm=[0, 2, 1]).contiguous(),
            embedding,
            cond=conds,
            streaming=streaming,
        )
        return {"loss": loss}

    @paddle.no_grad()
    def inference(
        self,
        token,
        token_len,
        prompt_token,
        prompt_token_len,
        prompt_feat,
        prompt_feat_len,
        embedding,
        streaming,
        finalize,
    ):
        assert token.shape[0] == 1
        embedding = paddle.nn.functional.normalize(x=embedding, axis=1)
        embedding = self.spk_embed_affine_layer(embedding)

        token, token_len = (
            paddle.concat([prompt_token, token], axis=1),
            prompt_token_len + token_len,
        )
        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(paddle.clip(token, min=0)) * mask
        if finalize is True:
            h, h_lengths = self.encoder(token, token_len, streaming=streaming)
        else:
            # Hold back pre_lookahead_len tokens as right context until more arrive.
            token, context = (
                token[:, : -self.pre_lookahead_len],
                token[:, -self.pre_lookahead_len :],
            )
            h, h_lengths = self.encoder(
                token, token_len, context=context, streaming=streaming
            )

        mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
        h = self.encoder_proj(h)
        conds = paddle.zeros(
            [1, mel_len1 + mel_len2, self.output_size], dtype=h.dtype
        )
        conds[:, :mel_len1] = prompt_feat
        conds = paddle.transpose(conds, perm=[0, 2, 1])
        mask = (~make_pad_mask(paddle.to_tensor([mel_len1 + mel_len2], dtype='int32'))).to(h)
        feat, _ = self.decoder(
            mu=paddle.transpose(h, perm=[0, 2, 1]).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=10,
            streaming=streaming,
        )

        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat.float(), None
@ -0,0 +1,342 @@
import paddle
from abc import ABC
from paddlespeech.t2s.models.CosyVoice.common import set_all_random_seed


class BASECFM(paddle.nn.Layer, ABC):
    def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 0.0001
        self.estimator = None

    @paddle.no_grad()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (paddle.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = paddle.randn(shape=mu.shape, dtype=mu.dtype) * temperature
        t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1)
        return self.solve_euler(
            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
        )

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (paddle.Tensor): random noise
            t_span (paddle.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (paddle.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        sol = []
        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (paddle.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (paddle.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape
        t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype)
        z = paddle.randn(shape=x1.shape, dtype=x1.dtype)
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        loss = paddle.nn.functional.mse_loss(
            input=self.estimator(y, mask, mu, t.squeeze(), spks),
            label=u,
            reduction="sum",
        ) / (paddle.sum(mask) * u.shape[1])
        return loss, y

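# solve_euler above integrates the learned vector field from t=0 to t=1 with
# fixed Euler steps, x <- x + dt * v(x, t). A numeric sketch with a
# hypothetical constant field v = 1 standing in for the estimator:
#
#     >>> x = paddle.zeros([1])
#     >>> t_span = paddle.linspace(0, 1, 11)      # n_timesteps = 10
#     >>> dt = t_span[1] - t_span[0]
#     >>> for step in range(1, len(t_span)):
#     ...     x = x + dt * paddle.ones_like(x)    # stand-in for estimator(...)
#     >>> float(x)   # ≈ 1.0: the constant field integrated over [0, 1]
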
class ConditionalCFM(BASECFM):
    def __init__(
        self,
        in_channels,
        cfm_params,
        n_spks=1,
        spk_emb_dim=64,
        estimator: paddle.nn.Layer = None,
    ):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        self.t_scheduler = cfm_params.t_scheduler
        self.training_cfg_rate = cfm_params.training_cfg_rate
        self.inference_cfg_rate = cfm_params.inference_cfg_rate
        in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        self.estimator = estimator

    @paddle.no_grad()
    def forward(
        self,
        mu,
        mask,
        n_timesteps,
        temperature=1.0,
        spks=None,
        cond=None,
        prompt_len=0,
        cache=paddle.zeros([1, 80, 0, 2]),
    ):
        """Forward diffusion

        Args:
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (paddle.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = (
            paddle.randn(shape=mu.shape, dtype=mu.dtype).to(mu.place).to(mu.dtype)
            * temperature
        )
        cache_size = cache.shape[2]
        if cache_size != 0:
            z[:, :, :cache_size] = cache[:, :, :, 0]
            mu[:, :, :cache_size] = cache[:, :, :, 1]
        # Cache the prompt region plus the last 34 frames of noise/condition so
        # the next streaming chunk resumes from an identical state.
        z_cache = paddle.concat([z[:, :, :prompt_len], z[:, :, -34:]], axis=2)
        mu_cache = paddle.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], axis=2)
        cache = paddle.stack([z_cache, mu_cache], axis=-1)
        t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1, dtype=mu.dtype)
        if self.t_scheduler == "cosine":
            t_span = 1 - paddle.cos(t_span * 0.5 * paddle.pi)
        return (
            self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond),
            cache,
        )

    def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
        """
        Fixed euler solver for ODEs.
        Args:
            x (paddle.Tensor): random noise
            t_span (paddle.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (paddle.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """

        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(axis=0)
        sol = []
        # Batch the conditional and unconditional (classifier-free guidance)
        # branches into one estimator call: row 0 is conditioned, row 1 is not.
        x_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype)
        mask_in = paddle.zeros([2, 1, x.shape[2]], dtype=x.dtype)
        mu_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype)
        t_in = paddle.zeros([2], dtype=x.dtype)
        spks_in = paddle.zeros([2, 80], dtype=x.dtype)
        cond_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype)
        for step in range(1, len(t_span)):
            x_in[:] = x
            mask_in[:] = mask
            mu_in[0] = mu
            t_in[:] = t.unsqueeze(0)
            spks_in[0] = spks
            cond_in[0] = cond
            dphi_dt = self.forward_estimator(
                x_in, mask_in, mu_in, t_in, spks_in, cond_in, streaming
            )
            dphi_dt, cfg_dphi_dt = paddle.split(dphi_dt, [x.shape[0], x.shape[0]], axis=0)
            dphi_dt = (
                1.0 + self.inference_cfg_rate
            ) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return sol[-1].float()

    def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
        if isinstance(self.estimator, paddle.nn.Layer):
            return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
        else:
            # TensorRT engine path: bind input shapes and raw device pointers,
            # then run the engine on the current CUDA stream.
            [estimator, stream], trt_engine = self.estimator.acquire_estimator()
            paddle.device.current_stream().synchronize()
            with stream:
                estimator.set_input_shape("x", (2, 80, x.shape[2]))
                estimator.set_input_shape("mask", (2, 1, x.shape[2]))
                estimator.set_input_shape("mu", (2, 80, x.shape[2]))
                estimator.set_input_shape("t", (2,))
                estimator.set_input_shape("spks", (2, 80))
                estimator.set_input_shape("cond", (2, 80, x.shape[2]))
                data_ptrs = [
                    x.contiguous().data_ptr(),
                    mask.contiguous().data_ptr(),
                    mu.contiguous().data_ptr(),
                    t.contiguous().data_ptr(),
                    spks.contiguous().data_ptr(),
                    cond.contiguous().data_ptr(),
                    x.data_ptr(),
                ]
                for i, j in enumerate(data_ptrs):
                    estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
                assert (
                    estimator.execute_async_v3(
                        paddle.device.current_stream().cuda_stream
                    )
                    is True
                )
                paddle.device.current_stream().synchronize()
            self.estimator.release_estimator(estimator, stream)
            return x

    def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
        """Computes diffusion loss

        Args:
            x1 (paddle.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (paddle.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape
        t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype)
        if self.t_scheduler == "cosine":
            t = 1 - paddle.cos(t * 0.5 * paddle.pi)
        z = paddle.randn(shape=x1.shape, dtype=x1.dtype)
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        if self.training_cfg_rate > 0:
            # Randomly drop conditioning for classifier-free guidance training;
            # cast the boolean mask so it can scale float tensors.
            cfg_mask = (paddle.rand(shape=[b]) > self.training_cfg_rate).astype(mu.dtype)
            mu = mu * cfg_mask.reshape([-1, 1, 1])
            spks = spks * cfg_mask.reshape([-1, 1])
            cond = cond * cfg_mask.reshape([-1, 1, 1])
        pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
        loss = paddle.nn.functional.mse_loss(
            input=pred * mask, label=u * mask, reduction="sum"
        ) / (paddle.sum(mask) * u.shape[1])
        return loss, y


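# The regression target u above is the time-derivative of the interpolation
# path y(t) = (1 - (1 - sigma_min) * t) * z + t * x1, i.e.
# u = dy/dt = x1 - (1 - sigma_min) * z. A quick finite-difference check:
#
#     >>> sigma_min, t, eps = 1e-06, 0.3, 1e-4
#     >>> z, x1 = paddle.randn([4]), paddle.randn([4])
#     >>> y = lambda s: (1 - (1 - sigma_min) * s) * z + s * x1
#     >>> numeric = (y(t + eps) - y(t - eps)) / (2 * eps)
#     >>> analytic = x1 - (1 - sigma_min) * z
#     >>> float((numeric - analytic).abs().max())   # ~0 up to float error
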
class CausalConditionalCFM(ConditionalCFM):
    def __init__(
        self,
        in_channels,
        cfm_params,
        n_spks=1,
        spk_emb_dim=64,
        estimator: paddle.nn.Layer = None,
    ):
        super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
        # Fix the noise once so streaming chunks are solved from the same
        # sample and stay consistent across calls.
        set_all_random_seed(42)
        self.rand_noise = paddle.randn([1, 80, 50 * 300])

    @paddle.no_grad()
    def forward(
        self,
        mu,
        mask,
        n_timesteps,
        temperature=1.0,
        spks=None,
        cond=None,
        streaming=False,
    ):
        """Forward diffusion

        Args:
            mu (paddle.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (paddle.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (paddle.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = self.rand_noise[:, :, : mu.shape[2]].to(mu.place).to(mu.dtype) * temperature
        t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1, dtype=mu.dtype)
        if self.t_scheduler == "cosine":
            t_span = 1 - paddle.cos(t_span * 0.5 * paddle.pi)

        return (
            self.solve_euler(
                z,
                t_span=t_span,
                mu=mu,
                mask=mask,
                spks=spks,
                cond=cond,
                streaming=streaming,
            ),
            None,
        )
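# The cosine schedule above warps the uniform grid so the solver takes small
# steps early and large steps late: t' = 1 - cos(pi * t / 2). For five points:
#
#     >>> t_span = paddle.linspace(0, 1, 5)
#     >>> (1 - paddle.cos(t_span * 0.5 * 3.14159265)).numpy().round(3)
#     array([0.   , 0.076, 0.293, 0.617, 1.   ], dtype=float32)
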
@ -0,0 +1,124 @@
from abc import ABC

import paddle
from matcha.models.components.decoder import Decoder
from matcha.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class BASECFM(paddle.nn.Layer, ABC):
    def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128):
        super().__init__()
        self.n_feats = n_feats
        self.n_spks = n_spks
        self.spk_emb_dim = spk_emb_dim
        self.solver = cfm_params.solver
        if hasattr(cfm_params, "sigma_min"):
            self.sigma_min = cfm_params.sigma_min
        else:
            self.sigma_min = 0.0001
        self.estimator = None

    @paddle.no_grad()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = paddle.randn(shape=mu.shape, dtype=mu.dtype) * temperature
        t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1)
        return self.solve_euler(
            z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond
        )

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        sol = []
        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return sol[-1]

    def compute_loss(self, x1, mask, mu, spks=None, cond=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape
        t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype)
        z = paddle.randn(shape=x1.shape, dtype=x1.dtype)
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z
        loss = paddle.nn.functional.mse_loss(
            input=self.estimator(y, mask, mu, t.squeeze(), spks),
            label=u,
            reduction="sum",
        ) / (paddle.sum(mask) * u.shape[1])
        return loss, y


class CFM(BASECFM):
    def __init__(
        self,
        in_channels,
        out_channel,
        cfm_params,
        decoder_params,
        n_spks=1,
        spk_emb_dim=64,
    ):
        super().__init__(
            n_feats=in_channels,
            cfm_params=cfm_params,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
        )
        in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
        self.estimator = Decoder(
            in_channels=in_channels, out_channels=out_channel, **decoder_params
        )
@ -0,0 +1,123 @@
from typing import Optional, Tuple, Union

import paddle


class LoRALinearLayer(paddle.nn.Layer):
    """
    A linear layer that is used with LoRA.

    Parameters:
        in_features (`int`):
            Number of input features.
        out_features (`int`):
            Number of output features.
        rank (`int`, `optional`, defaults to 4):
            The rank of the LoRA layer.
        network_alpha (`float`, `optional`, defaults to `None`):
            The value of the network alpha used for stable learning and preventing underflow. This value has the same
            meaning as the `--network_alpha` option in the kohya-ss trainer script. See
            https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        device (`paddle.CPUPlace`, `paddle.CUDAPlace` or `str`, `optional`, defaults to `None`):
            The device to use for the layer's weights.
        dtype (`paddle.dtype`, `optional`, defaults to `None`):
            The dtype to use for the layer's weights.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int = 4,
        network_alpha: Optional[float] = None,
        device: Optional[Union[paddle.CPUPlace, paddle.CUDAPlace, str]] = None,
        dtype: Optional[paddle.dtype] = None,
    ):
        super().__init__()
        self.down = paddle.nn.Linear(
            in_features=in_features, out_features=rank, bias_attr=False
        )
        self.up = paddle.nn.Linear(
            in_features=rank, out_features=out_features, bias_attr=False
        )
        self.network_alpha = network_alpha
        self.rank = rank
        self.out_features = out_features
        self.in_features = in_features
        # Paddle has no paddle.nn.init; use initializer objects instead.
        # down starts small, up starts at zero so the adapter begins as a no-op.
        paddle.nn.initializer.Normal(std=1 / rank)(self.down.weight)
        paddle.nn.initializer.Constant(value=0.0)(self.up.weight)

    def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype
        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)
        if self.network_alpha is not None:
            up_hidden_states *= self.network_alpha / self.rank
        return up_hidden_states.to(orig_dtype)


class LoRACompatibleLinear(paddle.nn.Linear):
    """
    A Linear layer that can be used with LoRA.
    """

    def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
        if self.lora_layer is None:
            return
        dtype, device = self.weight.data.dtype, self.weight.data.place
        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()
        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
        # Paddle Linear weights are laid out [in_features, out_features], so the
        # low-rank update is w_down @ w_up (the reverse of torch's order).
        fused_weight = (
            w_orig + lora_scale * paddle.bmm(w_down[None, :], w_up[None, :])[0]
        )
        if safe_fusing and paddle.isnan(fused_weight).any().item():
            raise ValueError(
                f"This LoRA weight seems to be broken. Encountered NaN values when trying to fuse LoRA weights for {self}. LoRA weights will not be fused."
            )
        self.weight.data = fused_weight.to(device=device, dtype=dtype)
        # Offload the factors so the fusion can be undone later.
        self.lora_layer = None
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        self._lora_scale = lora_scale

    def _unfuse_lora(self):
        if not (
            getattr(self, "w_up", None) is not None
            and getattr(self, "w_down", None) is not None
        ):
            return
        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.place
        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()
        unfused_weight = (
            fused_weight.float()
            - self._lora_scale * paddle.bmm(w_down[None, :], w_up[None, :])[0]
        )
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)
        self.w_up = None
        self.w_down = None

    def forward(
        self, hidden_states: paddle.Tensor, scale: float = 1.0
    ) -> paddle.Tensor:
        if self.lora_layer is None:
            out = super().forward(hidden_states)
            return out
        else:
            out = super().forward(hidden_states) + scale * self.lora_layer(
                hidden_states
            )
            return out
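# Minimal usage sketch (names illustrative): attach a rank-4 adapter to a
# LoRA-compatible linear layer. Because the up-projection is zero-initialized,
# the wrapped layer initially matches the plain linear output exactly.
#
#     >>> base = LoRACompatibleLinear(16, 16)
#     >>> base.set_lora_layer(LoRALinearLayer(16, 16, rank=4))
#     >>> x = paddle.randn([2, 16])
#     >>> delta = base(x) - paddle.nn.functional.linear(x, base.weight, base.bias)
#     >>> float(delta.abs().max())
#     0.0
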
@ -0,0 +1,287 @@
import paddle


def device2str(type=None, index=None, *, device=None):
    """Normalize a torch-style device spec into a paddle place string."""
    type = device if device else type
    if isinstance(type, int):
        type = f'gpu:{type}'
    elif isinstance(type, str):
        if 'cuda' in type:
            type = type.replace('cuda', 'gpu')
        if 'cpu' in type:
            type = 'cpu'
        elif index is not None:
            type = f'{type}:{index}'
    elif isinstance(type, paddle.CPUPlace) or (type is None):
        type = 'cpu'
    elif isinstance(type, paddle.CUDAPlace):
        type = f'gpu:{type.get_device_id()}'

    return type


def _Tensor_max(self, *args, **kwargs):
    # torch-style Tensor.max shim: with an `other` tensor it is an elementwise
    # maximum, with a `dim`/`axis` it returns (values, indices), otherwise it
    # is a global maximum.
    if "other" in kwargs:
        kwargs["y"] = kwargs.pop("other")
        ret = paddle.maximum(self, *args, **kwargs)
    elif len(args) == 1 and isinstance(args[0], paddle.Tensor):
        ret = paddle.maximum(self, *args, **kwargs)
    else:
        if "dim" in kwargs:
            kwargs["axis"] = kwargs.pop("dim")

        if "axis" in kwargs or len(args) >= 1:
            ret = paddle.max(self, *args, **kwargs), paddle.argmax(self, *args, **kwargs)
        else:
            ret = paddle.max(self, *args, **kwargs)

    return ret


setattr(paddle.Tensor, "_max", _Tensor_max)

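# device2str examples (pure string/place normalization, no GPU required):
#
#     >>> device2str('cuda:0')
#     'gpu:0'
#     >>> device2str(1)
#     'gpu:1'
#     >>> device2str('cpu')
#     'cpu'
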

def subsequent_mask(
    size: int,
    device: str = device2str("cpu"),
) -> paddle.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (str): "cpu" or "gpu:x"; kept for API compatibility and unused,
            since paddle creates tensors on the current place

    Returns:
        paddle.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = paddle.arange(size)  # paddle.arange takes no device argument
    mask = arange.expand([size, size])
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask

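# How the mask above is built without a Python loop: broadcast a row of
# positions against a column of positions and compare.
#
#     >>> arange = paddle.arange(3)
#     >>> arange.expand([3, 3]).numpy()          # each row is [0, 1, 2]
#     array([[0, 1, 2],
#            [0, 1, 2],
#            [0, 1, 2]])
#     >>> (arange.expand([3, 3]) <= arange.unsqueeze(-1)).astype('int64').numpy()
#     array([[1, 0, 0],
#            [1, 1, 0],
#            [1, 1, 1]])
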
def subsequent_chunk_mask_deprecated(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
>>>>>> device: torch.device = device2str("cpu"),
|
||||
) -> paddle.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
ret = paddle.zeros(size, size, device=device, dtype=paddle.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||
ret[i, start:ending] = True
|
||||
return ret
|
||||
|
||||
|
||||


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: str = "cpu",
) -> paddle.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    NOTE: this implementation keeps the full left context; ``num_left_chunks``
    is accepted for interface compatibility but not applied.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (str): kept for interface compatibility

    Returns:
        paddle.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    pos_idx = paddle.arange(size)
    block_value = (pos_idx // chunk_size + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret
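

# Quick check sketch (illustrative): with chunk_size=2, every frame attends up
# to the end of its own chunk, matching the docstring example.
#
#     m = subsequent_chunk_mask(4, 2).astype("int32")
#     # [[1, 1, 0, 0],
#     #  [1, 1, 0, 0],
#     #  [1, 1, 1, 1],
#     #  [1, 1, 1, 1]]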


def add_optional_chunk_mask(
    xs: paddle.Tensor,
    masks: paddle.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
    enable_full_context: bool = True,
):
    """Apply optional mask for encoder.

    Args:
        xs (paddle.Tensor): padded input, (B, L, D), L for max length
        masks (paddle.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context (max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        paddle.Tensor: chunk mask of the input xs.
    """
    if use_dynamic_chunk:
        max_len = xs.shape[1]
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            chunk_size = paddle.randint(low=1, high=max_len, shape=(1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = paddle.randint(
                        low=0, high=max_left_chunks, shape=(1,)
                    ).item()
        chunk_masks = subsequent_chunk_mask(
            xs.shape[1], chunk_size, num_left_chunks, xs.place
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.shape[1], static_chunk_size, num_left_chunks, xs.place
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == paddle.bool
    if (chunk_masks.sum(axis=-1) == 0).sum().item() != 0:
        print(
            "get chunk_masks all false at some timestep, force set to true, "
            "make sure they are masked in future computation!"
        )
        chunk_masks[chunk_masks.sum(axis=-1) == 0] = True
    return chunk_masks
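

# Typical call-site sketch (illustrative): inside an encoder forward pass the
# padding mask is combined with a chunk mask for dynamic-chunk training.
#
#     masks = (~make_pad_mask(xs_lens)).unsqueeze(1)   # (B, 1, L)
#     chunk_masks = add_optional_chunk_mask(
#         xs, masks,
#         use_dynamic_chunk=True, use_dynamic_left_chunk=False,
#         decoding_chunk_size=0, static_chunk_size=0,
#         num_decoding_left_chunks=-1)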


def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.shape[0]
    max_len = max_len if max_len > 0 else lengths._max().item()
    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
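

# Usage sketch (illustrative): True marks padded positions.
#
#     lengths = paddle.to_tensor([5, 3, 2])
#     pad = make_pad_mask(lengths)                     # (3, 5), bool
#     non_pad = ~pad                                   # valid positions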
@ -0,0 +1,445 @@
import math
from typing import Optional

import einops
import paddle
from conformer import ConformerBlock
from matcha.models.components.transformer import BasicTransformerBlock

ACTIVATION_FUNCTIONS = {
    "swish": paddle.nn.Silu(),
    "silu": paddle.nn.Silu(),
    "mish": paddle.nn.Mish(),
    "gelu": paddle.nn.GELU(),
    "relu": paddle.nn.ReLU(),
}


def get_activation(act_fn: str) -> paddle.nn.Layer:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function.

    Returns:
        paddle.nn.Layer: Activation function.
    """
    act_fn = act_fn.lower()
    if act_fn in ACTIVATION_FUNCTIONS:
        return ACTIVATION_FUNCTIONS[act_fn]
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


class SinusoidalPosEmb(paddle.nn.Layer):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim).astype(paddle.float32) * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
        return emb
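

# Usage sketch (illustrative): embeds a batch of scalar diffusion timesteps.
#
#     pos_emb = SinusoidalPosEmb(64)
#     t = paddle.to_tensor([0.1, 0.5])
#     e = pos_emb(t)                                   # (2, 64)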


class Block1D(paddle.nn.Layer):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.block = paddle.nn.Sequential(
            paddle.nn.Conv1D(dim, dim_out, 3, padding=1),
            paddle.nn.GroupNorm(num_groups=groups, num_channels=dim_out),
            paddle.nn.Mish(),
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class ResnetBlock1D(paddle.nn.Layer):
    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = paddle.nn.Sequential(
            paddle.nn.Mish(),
            paddle.nn.Linear(in_features=time_emb_dim, out_features=dim_out),
        )
        self.block1 = Block1D(dim, dim_out, groups=groups)
        self.block2 = Block1D(dim_out, dim_out, groups=groups)
        self.res_conv = paddle.nn.Conv1D(dim, dim_out, 1)

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h += self.mlp(time_emb).unsqueeze(-1)
        h = self.block2(h, mask)
        output = h + self.res_conv(x * mask)
        return output


class Downsample1D(paddle.nn.Layer):
    def __init__(self, dim):
        super().__init__()
        self.conv = paddle.nn.Conv1D(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class TimestepEmbedding(paddle.nn.Layer):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
    ):
        super().__init__()
        self.linear_1 = paddle.nn.Linear(
            in_features=in_channels, out_features=time_embed_dim
        )
        if cond_proj_dim is not None:
            self.cond_proj = paddle.nn.Linear(
                in_features=cond_proj_dim, out_features=in_channels, bias_attr=False
            )
        else:
            self.cond_proj = None
        self.act = get_activation(act_fn)
        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = paddle.nn.Linear(
            in_features=time_embed_dim, out_features=time_embed_dim_out
        )
        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)
        if self.act is not None:
            sample = self.act(sample)
        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
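

# Combined usage sketch (illustrative): sinusoidal features followed by the
# timestep MLP, mirroring what the Decoder below does in its forward pass.
#
#     time_mlp = TimestepEmbedding(in_channels=64, time_embed_dim=256)
#     t_feat = SinusoidalPosEmb(64)(paddle.to_tensor([0.5]))   # (1, 64)
#     cond = time_mlp(t_feat)                                  # (1, 256)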


class Upsample1D(paddle.nn.Layer):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(
        self,
        channels,
        use_conv=False,
        use_conv_transpose=True,
        out_channels=None,
        name="conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        self.conv = None
        if use_conv_transpose:
            self.conv = paddle.nn.Conv1DTranspose(
                in_channels=channels,
                out_channels=self.out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
            )
        elif use_conv:
            self.conv = paddle.nn.Conv1D(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs):
        assert inputs.shape[1] == self.channels
        if self.use_conv_transpose:
            return self.conv(inputs)
        outputs = paddle.nn.functional.interpolate(
            x=inputs, scale_factor=2.0, mode="nearest", data_format="NCW"
        )
        if self.use_conv:
            outputs = self.conv(outputs)
        return outputs


class ConformerWrapper(ConformerBlock):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0,
        ff_dropout=0,
        conv_dropout=0,
        conv_causal=False,
    ):
        super().__init__(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            conv_dropout=conv_dropout,
            conv_causal=conv_causal,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
    ):
        return super().forward(x=hidden_states, mask=attention_mask.astype(paddle.bool))


class Decoder(paddle.nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        channels=(256, 256),
        dropout=0.05,
        attention_head_dim=64,
        n_blocks=1,
        num_mid_blocks=2,
        num_heads=4,
        act_fn="snake",
        down_block_type="transformer",
        mid_block_type="transformer",
        up_block_type="transformer",
    ):
        super().__init__()
        channels = tuple(channels)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_embeddings = SinusoidalPosEmb(in_channels)
        time_embed_dim = channels[0] * 4
        self.time_mlp = TimestepEmbedding(
            in_channels=in_channels, time_embed_dim=time_embed_dim, act_fn="silu"
        )
        self.down_blocks = paddle.nn.LayerList(sublayers=[])
        self.mid_blocks = paddle.nn.LayerList(sublayers=[])
        self.up_blocks = paddle.nn.LayerList(sublayers=[])
        output_channel = in_channels
        for i in range(len(channels)):
            input_channel = output_channel
            output_channel = channels[i]
            is_last = i == len(channels) - 1
            resnet = ResnetBlock1D(
                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
            )
            transformer_blocks = paddle.nn.LayerList(
                sublayers=[
                    self.get_block(
                        down_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            downsample = (
                Downsample1D(output_channel)
                if not is_last
                else paddle.nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )
            self.down_blocks.append(
                paddle.nn.LayerList(sublayers=[resnet, transformer_blocks, downsample])
            )
        for i in range(num_mid_blocks):
            input_channel = channels[-1]
            out_channels = channels[-1]
            resnet = ResnetBlock1D(
                dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim
            )
            transformer_blocks = paddle.nn.LayerList(
                sublayers=[
                    self.get_block(
                        mid_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            self.mid_blocks.append(
                paddle.nn.LayerList(sublayers=[resnet, transformer_blocks])
            )
        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i]
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = ResnetBlock1D(
                dim=2 * input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = paddle.nn.LayerList(
                sublayers=[
                    self.get_block(
                        up_block_type,
                        output_channel,
                        attention_head_dim,
                        num_heads,
                        dropout,
                        act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else paddle.nn.Conv1D(output_channel, output_channel, 3, padding=1)
            )
            self.up_blocks.append(
                paddle.nn.LayerList(sublayers=[resnet, transformer_blocks, upsample])
            )
        self.final_block = Block1D(channels[-1], channels[-1])
        self.final_proj = paddle.nn.Conv1D(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    @staticmethod
    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
        if block_type == "conformer":
            block = ConformerWrapper(
                dim=dim,
                dim_head=attention_head_dim,
                heads=num_heads,
                ff_mult=1,
                conv_expansion_factor=2,
                ff_dropout=dropout,
                attn_dropout=dropout,
                conv_dropout=dropout,
                conv_kernel_size=31,
            )
        elif block_type == "transformer":
            block = BasicTransformerBlock(
                dim=dim,
                num_attention_heads=num_heads,
                attention_head_dim=attention_head_dim,
                dropout=dropout,
                activation_fn=act_fn,
            )
        else:
            raise ValueError(f"Unknown block type {block_type}")
        return block

    def initialize_weights(self):
        # paddle has no in-place torch.nn.init helpers; use the callable
        # initializer objects instead.
        kaiming = paddle.nn.initializer.KaimingNormal(nonlinearity="relu")
        zeros = paddle.nn.initializer.Constant(0.0)
        ones = paddle.nn.initializer.Constant(1.0)
        for m in self.sublayers():
            if isinstance(m, paddle.nn.Conv1D):
                kaiming(m.weight)
                if m.bias is not None:
                    zeros(m.bias)
            elif isinstance(m, paddle.nn.GroupNorm):
                ones(m.weight)
                zeros(m.bias)
            elif isinstance(m, paddle.nn.Linear):
                kaiming(m.weight)
                if m.bias is not None:
                    zeros(m.bias)

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (paddle.Tensor): shape (batch_size, in_channels, time)
            mask (paddle.Tensor): shape (batch_size, 1, time)
            mu (paddle.Tensor): conditioning features, shape (batch_size, in_channels, time)
            t (paddle.Tensor): shape (batch_size,)
            spks (paddle.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
            cond (paddle.Tensor, optional): placeholder for future use. Defaults to None.

        Returns:
            paddle.Tensor: output tensor, shape (batch_size, out_channels, time)
        """
        t = self.time_embeddings(t)
        t = self.time_mlp(t)
        x = einops.pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = einops.repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = einops.pack([x, spks], "b * t")[0]
        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = einops.rearrange(x, "b c t -> b t c")
            mask_down = einops.rearrange(mask_down, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x, attention_mask=mask_down, timestep=t
                )
            x = einops.rearrange(x, "b t c -> b c t")
            mask_down = einops.rearrange(mask_down, "b t -> b 1 t")
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]
        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = einops.rearrange(x, "b c t -> b t c")
            mask_mid = einops.rearrange(mask_mid, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x, attention_mask=mask_mid, timestep=t
                )
            x = einops.rearrange(x, "b t c -> b c t")
            mask_mid = einops.rearrange(mask_mid, "b t -> b 1 t")
        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            x = resnet(einops.pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
            x = einops.rearrange(x, "b c t -> b t c")
            mask_up = einops.rearrange(mask_up, "b 1 t -> b t")
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x, attention_mask=mask_up, timestep=t
                )
            x = einops.rearrange(x, "b t c -> b c t")
            mask_up = einops.rearrange(mask_up, "b t -> b 1 t")
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
@ -0,0 +1,276 @@
import numbers
from typing import Dict, Optional, Tuple

import paddle
from paddle import nn

from .embeddings import (CombinedTimestepLabelEmbeddings,
                         PixArtAlphaCombinedTimestepSizeEmbeddings)


def get_activation(act_fn):
    # Local definition shadows any imported helper of the same name.
    if act_fn == "silu":
        return nn.Silu()
    elif act_fn == "mish":
        return nn.Mish()
    elif act_fn == "relu":
        return nn.ReLU()
    elif act_fn == "gelu":
        return nn.GELU()
    else:
        raise ValueError(f"Unsupported activation function: {act_fn}")


class AdaLayerNorm(paddle.nn.Layer):
    """
    Norm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = paddle.nn.Embedding(num_embeddings, embedding_dim)
        self.silu = paddle.nn.Silu()
        self.linear = paddle.nn.Linear(
            in_features=embedding_dim, out_features=embedding_dim * 2
        )
        self.norm = paddle.nn.LayerNorm(
            normalized_shape=embedding_dim, weight_attr=False, bias_attr=False
        )

    def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor:
        emb = self.linear(self.silu(self.emb(timestep)))
        scale, shift = paddle.chunk(emb, 2)
        x = self.norm(x) * (1 + scale) + shift
        return x
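

# Modulation sketch (illustrative): a learned per-timestep embedding is
# projected to a scale and a shift that modulate the normalized activations,
# i.e. y = LayerNorm(x) * (1 + scale) + shift.
#
#     ada = AdaLayerNorm(embedding_dim=256, num_embeddings=1000)
#     y = ada(x, timestep)   # x: (..., 256), timestep: int64 index tensor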


class AdaLayerNormZero(paddle.nn.Layer):
    """
    Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`): The size of the embeddings dictionary.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
        super().__init__()
        if num_embeddings is not None:
            self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
        else:
            self.emb = None
        self.silu = paddle.nn.Silu()
        self.linear = paddle.nn.Linear(
            in_features=embedding_dim, out_features=6 * embedding_dim, bias_attr=True
        )
        self.norm = paddle.nn.LayerNorm(
            normalized_shape=embedding_dim,
            weight_attr=False,
            bias_attr=False,
            epsilon=1e-06,
        )

    def forward(
        self,
        x: paddle.Tensor,
        timestep: Optional[paddle.Tensor] = None,
        class_labels: Optional[paddle.Tensor] = None,
        hidden_dtype: Optional[paddle.dtype] = None,
        emb: Optional[paddle.Tensor] = None,
    ) -> Tuple[
        paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor
    ]:
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(
            6, axis=1
        )
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormSingle(paddle.nn.Layer):
    """
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
        super().__init__()
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
        )
        self.silu = paddle.nn.Silu()
        self.linear = paddle.nn.Linear(
            in_features=embedding_dim, out_features=6 * embedding_dim, bias_attr=True
        )

    def forward(
        self,
        timestep: paddle.Tensor,
        added_cond_kwargs: Optional[Dict[str, paddle.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[paddle.dtype] = None,
    ) -> Tuple[
        paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor
    ]:
        embedded_timestep = self.emb(
            timestep,
            **added_cond_kwargs,
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class AdaGroupNorm(paddle.nn.Layer):
    """
    GroupNorm layer modified to incorporate timestep embeddings.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        out_dim (`int`): The number of output channels.
        num_groups (`int`): The number of groups to separate the channels into.
        act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
        eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
    """

    def __init__(
        self,
        embedding_dim: int,
        out_dim: int,
        num_groups: int,
        act_fn: Optional[str] = None,
        eps: float = 1e-05,
    ):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        if act_fn is None:
            self.act = None
        else:
            self.act = get_activation(act_fn)
        self.linear = paddle.nn.Linear(
            in_features=embedding_dim, out_features=out_dim * 2
        )

    def forward(self, x: paddle.Tensor, emb: paddle.Tensor) -> paddle.Tensor:
        if self.act:
            emb = self.act(emb)
        emb = self.linear(emb)
        emb = emb[:, :, None, None]
        scale, shift = emb.chunk(2, axis=1)
        x = paddle.nn.functional.group_norm(
            x=x, num_groups=self.num_groups, epsilon=self.eps
        )
        x = x * (1 + scale) + shift
        return x


class AdaLayerNormContinuous(paddle.nn.Layer):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-05,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = paddle.nn.Silu()
        self.linear = paddle.nn.Linear(
            in_features=conditioning_embedding_dim,
            out_features=embedding_dim * 2,
            bias_attr=bias,
        )
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor
    ) -> paddle.Tensor:
        emb = self.linear(self.silu(conditioning_embedding).astype(x.dtype))
        scale, shift = paddle.chunk(emb, 2, axis=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


class LayerNorm(paddle.nn.Layer):
    def __init__(
        self,
        dim,
        eps: float = 1e-05,
        elementwise_affine: bool = True,
        bias: bool = True,
    ):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = tuple(dim)
        if elementwise_affine:
            self.weight = paddle.create_parameter(
                shape=list(dim),
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(1.0),
            )
            self.bias = (
                paddle.create_parameter(
                    shape=list(dim),
                    dtype="float32",
                    default_initializer=paddle.nn.initializer.Constant(0.0),
                )
                if bias
                else None
            )
        else:
            self.weight = None
            self.bias = None

    def forward(self, input):
        return paddle.nn.functional.layer_norm(
            input, self.dim, self.weight, self.bias, self.eps
        )


class RMSNorm(paddle.nn.Layer):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = tuple(dim)
        if elementwise_affine:
            self.weight = paddle.create_parameter(
                shape=list(dim),
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(1.0),
            )
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.astype(paddle.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * paddle.rsqrt(variance + self.eps)
        if self.weight is not None:
            if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
                hidden_states = hidden_states.astype(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        else:
            hidden_states = hidden_states.astype(input_dtype)
        return hidden_states


class GlobalResponseNorm(paddle.nn.Layer):
    def __init__(self, dim):
        super().__init__()
        self.gamma = paddle.create_parameter(
            shape=[1, 1, 1, dim],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0.0),
        )
        self.beta = paddle.create_parameter(
            shape=[1, 1, 1, dim],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Constant(0.0),
        )

    def forward(self, x):
        # Frobenius norm over the spatial axes equals torch's p=2 vector norm
        # over those dims; then normalize by its channel-wise mean.
        gx = paddle.linalg.norm(x, p="fro", axis=[1, 2], keepdim=True)
        nx = gx / (gx.mean(axis=-1, keepdim=True) + 1e-06)
        return self.gamma * (x * nx) + self.beta + x
@ -0,0 +1,241 @@
import base64
import os
from functools import lru_cache

import paddle
import tiktoken
from paddlenlp.transformers import AutoTokenizer

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
    "minnan": "minnan",
    "wuyu": "wuyu",
    "dialect": "dialect",
    "zh/en": "zh/en",
    "en/zh": "en/zh",
}

TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}

AUDIO_EVENT = {
    "ASR": "ASR",
    "AED": "AED",
    "SER": "SER",
    "Speech": "Speech",
    "/Speech": "/Speech",
    "BGM": "BGM",
    "/BGM": "/BGM",
    "Laughter": "Laughter",
    "/Laughter": "/Laughter",
    "Applause": "Applause",
    "/Applause": "/Applause",
}

EMOTION = {"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL"}

TTS_Vocal_Token = {
    "TTS/B": "TTS/B",
    "TTS/O": "TTS/O",
    "TTS/Q": "TTS/Q",
    "TTS/A": "TTS/A",
    "TTS/CO": "TTS/CO",
    "TTS/CL": "TTS/CL",
    "TTS/H": "TTS/H",
    **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)},
}


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as f:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in f if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}
    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
        *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)],
        *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())],
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1
    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str="'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
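

# Usage sketch (illustrative; assumes the matching assets/gpt2.tiktoken
# vocabulary file ships with the package):
#
#     enc = get_encoding("gpt2")
#     ids = enc.encode("hello world")
#     assert enc.decode(ids) == "hello world"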


class QwenTokenizer:
    def __init__(self, skip_special_tokens=True):
        super().__init__()
        special_tokens = {
            "eos_token": "<|endoftext|>",
            "pad_token": "<|endoftext|>",
            "additional_special_tokens": [
                "<|im_start|>",
                "<|im_end|>",
                "<|endofprompt|>",
                "[breath]",
                "<strong>",
                "</strong>",
                "[noise]",
                "[laughter]",
                "[cough]",
                "[clucking]",
                "[accent]",
                "[quick_breath]",
                "<laughter>",
                "</laughter>",
                "[hissing]",
                "[sigh]",
                "[vocalized-noise]",
                "[lipsmack]",
                "[mn]",
            ],
        }
        self.special_tokens = special_tokens
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
        self.tokenizer.add_special_tokens(special_tokens)
        self.skip_special_tokens = skip_special_tokens

    def encode(self, text, **kwargs):
        tokens = self.tokenizer([text], return_tensors="pd")
        tokens = tokens["input_ids"][0].cpu().tolist()
        return tokens

    def decode(self, tokens):
        tokens = paddle.to_tensor(tokens, dtype=paddle.int64)
        text = self.tokenizer.batch_decode(
            [tokens], skip_special_tokens=self.skip_special_tokens
        )[0]
        return text


@lru_cache(maxsize=None)
def get_qwen_tokenizer(skip_special_tokens: bool) -> QwenTokenizer:
    return QwenTokenizer(skip_special_tokens=skip_special_tokens)
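

# Usage sketch (illustrative; downloads the Qwen/Qwen2-0.5B tokenizer files on
# first call):
#
#     tokenizer = get_qwen_tokenizer(skip_special_tokens=True)
#     ids = tokenizer.encode("hello [breath] world")
#     print(tokenizer.decode(ids))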
@ -0,0 +1,113 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ConvolutionModule definition."""
from typing import Tuple

import paddle


class ConvolutionModule(paddle.nn.Layer):
    """ConvolutionModule in Conformer model."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 15,
        activation: paddle.nn.Layer = paddle.nn.ReLU(),
        norm: str = "batch_norm",
        causal: bool = False,
        bias: bool = True,
    ):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (bool): Whether to use causal convolution or not.
        """
        super().__init__()
        self.pointwise_conv1 = paddle.nn.Conv1D(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias_attr=bias
        )
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = paddle.nn.Conv1D(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias_attr=bias,
        )
        assert norm in ["batch_norm", "layer_norm"]
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = paddle.nn.BatchNorm1D(num_features=channels)
        else:
            self.use_layer_norm = True
            self.norm = paddle.nn.LayerNorm(normalized_shape=channels)
        self.pointwise_conv2 = paddle.nn.Conv1D(
            channels, channels, kernel_size=1, stride=1, padding=0, bias_attr=bias
        )
        self.activation = activation

    def forward(
        self,
        x: paddle.Tensor,
        mask_pad: paddle.Tensor = paddle.ones([0, 0, 0], dtype=paddle.bool),
        cache: paddle.Tensor = paddle.zeros([0, 0, 0]),
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute convolution module.
        Args:
            x (paddle.Tensor): Input tensor (#batch, time, channels).
            mask_pad (paddle.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (paddle.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            paddle.Tensor: Output tensor (#batch, time, channels).
            paddle.Tensor: New left context cache.
        """
        # (B, T, C) -> (B, C, T)
        x = x.transpose([0, 2, 1])
        if mask_pad.shape[2] > 0:
            x = paddle.masked_fill(x, ~mask_pad, 0.0)
        if self.lorder > 0:
            if cache.shape[2] == 0:
                x = paddle.nn.functional.pad(
                    x, [self.lorder, 0], mode="constant", value=0.0, data_format="NCL"
                )
            else:
                assert cache.shape[0] == x.shape[0]
                assert cache.shape[1] == x.shape[1]
                x = paddle.concat([cache, x], axis=2)
            assert x.shape[2] > self.lorder
            new_cache = x[:, :, -self.lorder:]
        else:
            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
        x = self.pointwise_conv1(x)
        x = paddle.nn.functional.glu(x=x, axis=1)
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose([0, 2, 1])
        x = self.pointwise_conv2(x)
        if mask_pad.shape[2] > 0:
            x = paddle.masked_fill(x, ~mask_pad, 0.0)
        return x.transpose([0, 2, 1]), new_cache
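

# Streaming sketch (illustrative): with causal=True, the left-context cache
# returned for one chunk is fed into the call for the next chunk.
#
#     conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
#     cache = paddle.zeros([1, 256, 0])
#     ys = []
#     for chunk in chunks:                     # each chunk: (1, T_chunk, 256)
#         y, cache = conv(chunk, cache=cache)
#         ys.append(y)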
@ -0,0 +1,275 @@
"""Positional Encoding Module."""
import math
from typing import Tuple, Union

import numpy as np
import paddle


class PositionalEncoding(paddle.nn.Layer):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """

    def __init__(
        self,
        d_model: int,
        dropout_rate: float,
        max_len: int = 5000,
        reverse: bool = False,
    ):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = paddle.nn.Dropout(p=dropout_rate)
        self.max_len = max_len
        self.pe = paddle.zeros([self.max_len, self.d_model])
        position = paddle.arange(0, self.max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        self.pe[:, 0::2] = paddle.sin(position * div_term)
        self.pe[:, 1::2] = paddle.cos(position * div_term)
        self.pe = self.pe.unsqueeze(0)

    def forward(
        self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Add positional encoding.

        Args:
            x (paddle.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, paddle.Tensor): position offset

        Returns:
            paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            paddle.Tensor: for compatibility to RelPositionalEncoding
        """
        self.pe = self.pe.to(x.place)
        pos_emb = self.position_encoding(offset, x.shape[1], False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(
        self, offset: Union[int, paddle.Tensor], size: int, apply_dropout: bool = True
    ) -> paddle.Tensor:
        """For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a
        non-streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or paddle.Tensor): start offset
            size (int): required size of position encoding

        Returns:
            paddle.Tensor: Corresponding encoding
        """
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, paddle.Tensor) and offset.dim() == 0:
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:
            assert paddle.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + paddle.arange(0, size).to(offset.place)
            flag = index > 0
            index = index * flag
            pos_emb = paddle.nn.functional.embedding(index, self.pe[0])
        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
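

# Streaming sketch (illustrative): position encodings are fetched with a
# running offset so consecutive chunks line up.
#
#     pe = PositionalEncoding(d_model=256, dropout_rate=0.1)
#     emb1 = pe.position_encoding(offset=0, size=10, apply_dropout=False)
#     emb2 = pe.position_encoding(offset=10, size=10, apply_dropout=False)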


class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(
        self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Compute positional encoding.
        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).
        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).
            paddle.Tensor: Positional embedding tensor (1, time, `*`).
        """
        self.pe = self.pe.to(x.place)
        x = x * self.xscale
        pos_emb = self.position_encoding(offset, x.shape[1], False)
        return self.dropout(x), self.dropout(pos_emb)


class WhisperPositionalEncoding(PositionalEncoding):
    """Sinusoids position encoding used in openai-whisper.encoder"""

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
        super().__init__(d_model, dropout_rate, max_len)
        self.xscale = 1.0
        log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
        inv_timescales = paddle.exp(
            -log_timescale_increment
            * paddle.arange(d_model // 2).astype(paddle.float32)
        )
        scaled_time = (
            paddle.arange(max_len, dtype=paddle.float32)[:, np.newaxis]
            * inv_timescales[np.newaxis, :]
        )
        pe = paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1)
        delattr(self, "pe")
        self.register_buffer(name="pe", tensor=pe.unsqueeze(0))


class LearnablePositionalEncoding(PositionalEncoding):
    """Learnable position encoding used in openai-whisper.decoder"""

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
        super().__init__(d_model, dropout_rate, max_len)
        self.pe = paddle.create_parameter(
            shape=[1, max_len, d_model], dtype="float32"
        )
        self.xscale = 1.0


class NoPositionalEncoding(paddle.nn.Layer):
    """No position encoding"""

    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = paddle.nn.Dropout(p=dropout_rate)

    def forward(
        self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Just return zero vector for interface compatibility"""
        pos_emb = paddle.zeros([1, x.shape[1], self.d_model])
        return self.dropout(x), pos_emb

    def position_encoding(
        self, offset: Union[int, paddle.Tensor], size: int
    ) -> paddle.Tensor:
        return paddle.zeros([1, size, self.d_model])


class EspnetRelPositionalEncoding(paddle.nn.Layer):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct a PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = paddle.nn.Dropout(p=dropout_rate)
        self.pe = None
        # A dummy (1, max_len) tensor is enough to size the initial table.
        self.extend_pe(paddle.zeros([1, max_len]))

    def extend_pe(self, x: paddle.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.shape[1] >= x.shape[1] * 2 - 1:
                if self.pe.dtype != x.dtype:
                    self.pe = self.pe.astype(x.dtype)
                return
        pe_positive = paddle.zeros([x.shape[1], self.d_model])
        pe_negative = paddle.zeros([x.shape[1], self.d_model])
        position = paddle.arange(0, x.shape[1], dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, self.d_model, 2, dtype=paddle.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = paddle.sin(position * div_term)
        pe_positive[:, 1::2] = paddle.cos(position * div_term)
        pe_negative[:, 0::2] = paddle.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = paddle.cos(-1 * position * div_term)
        pe_positive = paddle.flip(pe_positive, axis=[0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = paddle.concat([pe_positive, pe_negative], axis=1)
        self.pe = pe.astype(x.dtype)

    def forward(
        self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Add positional encoding.

        Args:
            x (paddle.Tensor): Input tensor (batch, time, `*`).

        Returns:
            paddle.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.shape[1], offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(
        self, offset: Union[int, paddle.Tensor], size: int
    ) -> paddle.Tensor:
        """For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a
        non-streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or paddle.Tensor): start offset
            size (int): required size of position encoding

        Returns:
            paddle.Tensor: Corresponding encoding
        """
        if isinstance(offset, int):
            pos_emb = self.pe[
                :,
                self.pe.shape[1] // 2 - size - offset + 1 : self.pe.shape[1] // 2 + size + offset,
            ]
        elif isinstance(offset, paddle.Tensor):
            pos_emb = self.pe[
                :,
                self.pe.shape[1] // 2 - size - offset + 1 : self.pe.shape[1] // 2 + size + offset,
            ]
        return pos_emb
@ -0,0 +1,352 @@
"""Encoder definition."""
from typing import Tuple

import paddle
import paddle.nn.functional as F

from paddlespeech.t2s.models.CosyVoice.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
)
from paddlespeech.t2s.modules.transformer.convolution import ConvolutionModule
from paddlespeech.t2s.modules.transformer.encoder_layer import ConformerEncoderLayer
from paddlespeech.t2s.modules.transformer.mask import add_optional_chunk_mask, make_pad_mask
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward


class Upsample1D(paddle.nn.Layer):
    """A 1D nearest-neighbour upsampling layer followed by a causal convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        out_channels (`int`):
            number of output channels.
        stride (`int`, default `2`):
            upsampling factor along the time axis.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # kernel (stride * 2 + 1) with left padding (stride * 2) keeps the
        # convolution causal after upsampling.
        self.conv = paddle.nn.Conv1D(
            self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0
        )

    def forward(
        self, inputs: paddle.Tensor, input_lengths: paddle.Tensor
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        inputs = inputs.unsqueeze(2)
        outputs = paddle.nn.functional.interpolate(
            x=inputs, scale_factor=[1, float(self.stride)], mode="nearest"
        )
        outputs = outputs.squeeze(2)
        outputs = F.pad(outputs, [self.stride * 2, 0], value=0.0, data_format="NCL")
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride
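

# Shape sketch (illustrative): stride=2 doubles the time axis while the causal
# convolution keeps the output aligned to the past.
#
#     up = Upsample1D(channels=512, out_channels=512, stride=2)
#     y, y_lens = up(paddle.randn([1, 512, 50]), paddle.to_tensor([50]))
#     # y: (1, 512, 100), y_lens: [100]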


class PreLookaheadLayer(paddle.nn.Layer):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = paddle.nn.Conv1D(
            channels, channels, kernel_size=pre_lookahead_len + 1, stride=1, padding=0
        )
        self.conv2 = paddle.nn.Conv1D(
            channels, channels, kernel_size=3, stride=1, padding=0
        )

    def forward(
        self, inputs: paddle.Tensor, context: paddle.Tensor = paddle.zeros([0, 0, 0])
    ) -> paddle.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = paddle.transpose(inputs, perm=[0, 2, 1])
        context = paddle.transpose(context, perm=[0, 2, 1])

        if context.shape[2] == 0:
            outputs = F.pad(
                outputs,
                [0, self.pre_lookahead_len],
                mode="constant",
                value=0.0,
                data_format="NCL",
            )
        else:
            assert (
                self.training is False
            ), "you have passed context, make sure that you are running inference mode"
            assert context.shape[2] == self.pre_lookahead_len
            outputs = F.pad(
                paddle.concat([outputs, context], axis=2),
                [0, self.pre_lookahead_len - context.shape[2]],
                mode="constant",
                value=0.0,
                data_format="NCL",
            )

        outputs = paddle.nn.functional.leaky_relu(x=self.conv1(outputs))

        outputs = F.pad(
            outputs,
            [self.conv2._kernel_size[0] - 1, 0],
            mode="constant",
            value=0.0,
            data_format="NCL",
        )
        outputs = self.conv2(outputs)
        outputs = paddle.transpose(outputs, perm=[0, 2, 1])

        # residual connection
        outputs = outputs + inputs
        return outputs


class UpsampleConformerEncoder(paddle.nn.Layer):
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: paddle.nn.Layer = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of decoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use fixed chunk (chunk_size > 0)
                or dynamic chunk size (use_dynamic_chunk = True)
            global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
            key_bias: whether to use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
        """
|
||||
        super().__init__()
        self._output_size = output_size
        self.global_cmvn = global_cmvn
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size, positional_dropout_rate
            ),
        )
        self.normalize_before = normalize_before
        self.after_norm = paddle.nn.LayerNorm(
            normalized_shape=output_size, epsilon=1e-05
        )
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            False,
        )
        positionwise_layer_args = (output_size, linear_units, dropout_rate, activation)
        convolution_layer_args = (
            output_size,
            cnn_module_kernel,
            activation,
            cnn_module_norm,
            causal,
        )
        # Fixed-size lookahead layer applied before the main conformer stack.
        self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
        self.encoders = paddle.nn.LayerList(
            sublayers=[
                ConformerEncoderLayer(
                    output_size,
                    COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                        *encoder_selfattn_layer_args
                    ),
                    PositionwiseFeedForward(*positionwise_layer_args),
                    PositionwiseFeedForward(*positionwise_layer_args)
                    if macaron_style
                    else None,
                    ConvolutionModule(*convolution_layer_args)
                    if use_cnn_module
                    else None,
                    dropout_rate,
                    normalize_before,
                )
                for _ in range(num_blocks)
            ]
        )
        # 2x temporal upsampling between the two conformer stacks.
        self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size, positional_dropout_rate
            ),
        )
        # Post-upsampling conformer stack (fixed at 4 blocks).
        self.up_encoders = paddle.nn.LayerList(
            sublayers=[
                ConformerEncoderLayer(
                    output_size,
                    COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                        *encoder_selfattn_layer_args
                    ),
                    PositionwiseFeedForward(*positionwise_layer_args),
                    PositionwiseFeedForward(*positionwise_layer_args)
                    if macaron_style
                    else None,
                    ConvolutionModule(*convolution_layer_args)
                    if use_cnn_module
                    else None,
                    dropout_rate,
                    normalize_before,
                )
                for _ in range(4)
            ]
        )
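
    # Architecture note (added): the data path is
    #   embed -> pre_lookahead_layer -> encoders (num_blocks)
    #   -> up_layer (stride 2) -> up_embed -> up_encoders (4) -> after_norm,
    # so the output sequence is roughly twice the input length.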

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: paddle.Tensor,
        xs_lens: paddle.Tensor,
        context: paddle.Tensor = paddle.zeros([0, 0, 0]),
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
        streaming: bool = False,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs: padded input tensor (B, T, D)
|
||||
xs_lens: input length (B)
|
||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
Returns:
|
||||
encoder output tensor xs, and subsampled masks
|
||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||
masks: torch.Tensor batch padding mask after subsample
|
||||
(B, 1, T' ~= T/subsample_rate)
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
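        # Illustrative note (added): with streaming=True and, say,
        # static_chunk_size=25, add_optional_chunk_mask below builds
        # block-wise attention masks 25 frames wide; with streaming=False
        # a chunk size of 0 is passed and attention sees the full sequence.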
        T = xs.shape[1]
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        if context.shape[1] != 0:
            assert (
                self.training is False
            ), "you have passed context, make sure that you are running inference mode"
            # paddle.ones takes a shape list; match the dtype of `masks`
            # (the torch-style `.to(masks)` idiom has no Paddle equivalent).
            context_masks = paddle.ones([1, 1, context.shape[1]], dtype=masks.dtype)
            context, _, _ = self.embed(context, context_masks, offset=xs.shape[1])
        mask_pad = masks
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            False,
            False,
            0,
            self.static_chunk_size if streaming is True else 0,
            -1,
        )
        xs = self.pre_lookahead_layer(xs, context=context)
        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)

        # Upsample along time: (B, T, C) -> (B, C, T) -> up_layer -> back.
        xs = paddle.transpose(xs, perm=[0, 2, 1]).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = paddle.transpose(xs, perm=[0, 2, 1]).contiguous()
        T = xs.shape[1]
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        mask_pad = masks
        chunk_masks = add_optional_chunk_mask(
            xs,
            masks,
            False,
            False,
            0,
            self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
            -1,
        )
        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks
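
    # Note (added): `gradient_checkpointing` is stored in __init__, but the
    # loops below run the layers directly; checkpointed recompute would have
    # to be wired in here if it were needed.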
    def forward_layers(
        self,
        xs: paddle.Tensor,
        chunk_masks: paddle.Tensor,
        pos_emb: paddle.Tensor,
        mask_pad: paddle.Tensor,
    ) -> paddle.Tensor:
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs

    def forward_up_layers(
        self,
        xs: paddle.Tensor,
        chunk_masks: paddle.Tensor,
        pos_emb: paddle.Tensor,
        mask_pad: paddle.Tensor,
    ) -> paddle.Tensor:
        for layer in self.up_encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        return xs
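

# --- Usage sketch (added; not part of the original port) ---
# A minimal smoke test under stated assumptions: the registries defined
# earlier in this file (COSYVOICE_SUBSAMPLE_CLASSES, COSYVOICE_EMB_CLASSES,
# COSYVOICE_ATTENTION_CLASSES) resolve "linear" and the default "rel_pos" /
# "rel_selfattn" keys, and the hard-coded channel width of 512 in
# PreLookaheadLayer and Upsample1D forces output_size=512.
if __name__ == "__main__":
    encoder = UpsampleConformerEncoder(
        input_size=512,
        output_size=512,
        attention_heads=8,
        input_layer="linear",
    )
    encoder.eval()
    xs = paddle.randn([2, 50, 512])        # (B, T, D) padded batch
    xs_lens = paddle.to_tensor([50, 32])   # valid lengths per item
    ys, masks = encoder(xs, xs_lens)
    # up_layer has stride 2, so T roughly doubles on the way out.
    print(ys.shape, masks.shape)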