diff --git a/hidden_states.pdparams b/hidden_states.pdparams new file mode 100644 index 000000000..94c37f8b5 Binary files /dev/null and b/hidden_states.pdparams differ diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py index 6c7e75c1f..fcf14c111 100644 --- a/paddlespeech/__init__.py +++ b/paddlespeech/__init__.py @@ -13,3 +13,7 @@ # limitations under the License. import _locale _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) + + + + diff --git a/paddlespeech/audio/functional/mel_extract.py b/paddlespeech/audio/functional/mel_extract.py new file mode 100644 index 000000000..8b8830aff --- /dev/null +++ b/paddlespeech/audio/functional/mel_extract.py @@ -0,0 +1,85 @@ +import numpy as np +import paddle +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-05): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-05): + return paddle.log(paddle.clip(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return paddle.exp(x=x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + y = paddle.to_tensor(y.detach().cpu().numpy()) + if paddle.min(paddle.to_tensor(y)) < -1.0: + print("min value is ", paddle.min(paddle.to_tensor(y))) + if paddle.max(paddle.to_tensor(y)) > 1.0: + print("max value is ", paddle.max(paddle.to_tensor(y))) + global mel_basis, hann_window + if f"{str(fmax)}_{str(y.place)}" not in mel_basis: + mel = librosa_mel_fn( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[str(fmax) + "_" + str(y.place)] = ( + paddle.to_tensor(mel).float().to(y.place) + ) + hann_window[str(y.place)] = paddle.audio.functional.get_window( + win_length=win_size, dtype="float32", window="hann" + ).to(y.place) + + y = paddle.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + window = paddle.load("/root/paddlejob/workspace/zhangjinghong/test/PaddleSpeech/matcha_window.pdparams").cuda() + + stft = paddle.signal.stft( + y.cuda(), + n_fft=1920, + hop_length=480, + window=window + ) + + spec = paddle.as_real( + stft + ) + spec = paddle.sqrt(spec.pow(2).sum(-1) + 1e-09) + spec = paddle.matmul(mel_basis[str(fmax) + "_" + str(y.place)], spec) + spec = spectral_normalize_torch(spec) + return spec diff --git a/paddlespeech/cli/tts/cosyvoice.py b/paddlespeech/cli/tts/cosyvoice.py new file mode 100644 index 000000000..13e7768bf --- /dev/null +++ b/paddlespeech/cli/tts/cosyvoice.py @@ -0,0 +1,130 @@ +from paddlespeech.t2s.models.CosyVoice.cosyvoice import CosyVoice2 +import sys +from paddlenlp.transformers import AutoTokenizer, AutoModelForCausalLM +from pathlib import Path +import paddle +import torch +paddle.seed(42) +from paddlespeech.t2s.frontend.CosyVoiceFrontEnd.frontend import CosyVoiceFrontEnd +from paddlespeech.t2s.models.CosyVoice.llm import 
Qwen2LM,Qwen2Encoder +from paddlespeech.t2s.models.CosyVoice.common import ras_sampling +from paddlespeech.t2s.frontend.CosyVoiceFrontEnd.tokenizer import get_qwen_tokenizer +from hyperpyyaml import load_hyperpyyaml +hyper_yaml_path = "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/cosyvoice2.yaml" +with open(hyper_yaml_path, 'r') as f: + configs = load_hyperpyyaml(f) + +# frontend = CosyVoiceFrontEnd( +# lambda:get_qwen_tokenizer('/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/CosyVoice-BlankEN',skip_special_tokens=True), +# configs['feat_extractor'], +# "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/campplus.onnx", +# "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/speech_tokenizer_v2.onnx", +# "/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B/spk2info.pt", +# configs['allowed_special'] +# ) +# prompt_wav = "../CosyVoice/gaoziyuan_10.wav" +# tts_text = frontend.text_normalize("定位解决了云存储服务和 data loader 等多个环节的性能波动问题", split=True, text_frontend=True) +# prompt_text = frontend.text_normalize('清晨的阳光透过树叶洒在地面上,微风轻轻吹过,带来花草的香气。街边的咖啡店刚开门,传来阵阵烘焙的香味,让人感到放松与愉快。', split=True, text_frontend=True) +# model_input = frontend.frontend_zero_shot(tts_text, prompt_text, prompt_wav, 24000,'') +# paddle.save(model_input,'model_input.pdparams') +# # cosyvoice_model = CosyVoice2("../CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle") +# model = AutoModelForCausalLM.from_pretrained('/root/paddlejob/workspace/zhangjinghong/test/pretrained/Qwen/Qwen2-0.5B') +# print(type(model)) +# llm = Qwen2Encoder(model) +# qwen_lm = Qwen2LM(896,896,6561,llm,ras_sampling) +# state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/llm.pdparams") +# qwen_lm.set_state_dict(state_dict) + + +# new_dict = torch.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/data.pt") +# text = new_dict['text'] +# text_len = new_dict['text_len'] +# prompt_text = new_dict['prompt_text'] +# prompt_text_len = new_dict['prompt_text_len'] +# prompt_speech_token = new_dict['prompt_speech_token'] +# prompt_speech_token_len = new_dict['prompt_speech_token_len'] +# embedding = new_dict['embedding'] +# uuid = new_dict['uuid'] + + +# text = model_input['text'] +# text_len = model_input['text_len'] +# prompt_text = model_input['prompt_text'] +# prompt_text_len = model_input['prompt_text_len'] +# prompt_speech_token = model_input['llm_prompt_speech_token'] +# prompt_speech_token_len = model_input['llm_prompt_speech_token_len'] +# embedding = model_input['llm_embedding'] +# uuid = new_dict['uuid'] + +# # 统一设备并转换为Paddle张量 +# device = paddle.CUDAPlace(0) # 使用GPU设备 +# text_tensor = paddle.to_tensor(text).cuda() +# prompt_text_tensor = paddle.to_tensor(prompt_text).cuda() +# prompt_speech_token_tensor = paddle.to_tensor(prompt_speech_token).cuda() +# embedding_tensor = paddle.to_tensor(embedding, dtype='float32').cuda() +# # 确保长度张量也统一设备并正确转换 +# text_len_tensor = text_len.cuda() if hasattr(text_len, 'cuda') else paddle.to_tensor(text_len).cuda() +# prompt_text_len_tensor = prompt_text_len.cuda() if hasattr(prompt_text_len, 'cuda') else paddle.to_tensor(prompt_text_len).cuda() +# prompt_speech_token_len_tensor = prompt_speech_token_len.cuda() if hasattr(prompt_speech_token_len, 'cuda') else paddle.to_tensor(prompt_speech_token_len).cuda() +# token=[] +# for i in qwen_lm.inference(text=text_tensor, +# 
text_len=text_len_tensor, +# prompt_text=prompt_text_tensor, +# prompt_text_len=prompt_text_len_tensor, +# prompt_speech_token=prompt_speech_token_tensor, +# prompt_speech_token_len=prompt_speech_token_len_tensor, +# embedding=embedding_tensor, +# uuid=uuid): +# token.append(i) +# # print(text) +# print("token: ",i) + +############################################################################################################################ + + +flow = configs['flow'] +flow_state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/flow.pdparams") +flow.set_state_dict(flow_state_dict) +input_dict = torch.load("/root/paddlejob/workspace/zhangjinghong/test/CosyVoice/data.pt") +flow.eval() +tts_mel, _ = flow.inference( + token = paddle.to_tensor(input_dict['token']), + token_len = paddle.to_tensor(input_dict['token_len']), + prompt_token = paddle.to_tensor(input_dict['prompt_token'].cpu().numpy(), dtype = 'int32'), + prompt_token_len = paddle.to_tensor(input_dict['prompt_token_len'].cpu().numpy()), + prompt_feat = paddle.to_tensor(input_dict['prompt_feat'].cpu().numpy()), + prompt_feat_len = paddle.to_tensor(input_dict['prompt_feat_len'].cpu().numpy()), + embedding = paddle.to_tensor(input_dict['embedding'].cpu().numpy()), + streaming = input_dict['streaming'], + finalize = input_dict['finalize'] +) +paddle.save(tts_mel,"tts_mel.pdparams") + +############################################################################################################################ + +from paddlespeech.t2s.models.hifigan.cosy_hifigan import HiFTGenerator +from paddlespeech.t2s.models.hifigan.f0_predictor import ConvRNNF0Predictor +hift_state_dict = paddle.load("/root/paddlejob/workspace/zhangjinghong/CosyVoice/pretrained_models/CosyVoice2-0.5B_paddle/hift.pdparams") +input_mel = paddle.to_tensor(torch.load("../CosyVoice/tts_mel.pt").detach().cpu().numpy()).cuda() +hift_cache_source= paddle.to_tensor(torch.load("../CosyVoice/hift_cache_source.pt").detach().cpu().numpy()).cuda() +# hift_cache_source = paddle.zeros([1, 1, 0]) +hift_configs = configs['hift'] +f0_config = configs['f0_predictor'] +f0_predictor = ConvRNNF0Predictor(**f0_config) + +hift_configs['f0_predictor'] = f0_predictor +hift = HiFTGenerator(**hift_configs) +hift.set_state_dict(hift_state_dict) +# for k,v in hift.state_dict().items(): +# print(k,v.shape) +# print("---"*40) +for k,v in hift_state_dict.items(): + print(k,v.shape) +tts_speech, tts_source = hift.inference(speech_feat=input_mel, cache_source=hift_cache_source) +paddle.save(tts_speech,"speech.pdparams") +# tts_speech,_ = hift.inference(input_dict['tts_mel'],input_dict['cache_source']) + +import torchaudio +import torch +torchaudio.save("paddle.wav",torch.tensor(tts_speech.numpy()),24000) +# sf.write("paddle.wav",tts_speech[0],24000) \ No newline at end of file diff --git a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/__init__.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/file_utils.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/file_utils.py new file mode 100644 index 000000000..5679f31a5 --- /dev/null +++ b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/file_utils.py @@ -0,0 +1,110 @@ +import json +import logging +import os + +import paddle +import torchaudio +import paddlespeech + +logging.getLogger("matplotlib").setLevel(logging.WARNING) +logging.basicConfig(level=logging.DEBUG, format="%(asctime)s 
%(levelname)s %(message)s") + + +def read_lists(list_file): + lists = [] + with open(list_file, "r", encoding="utf8") as fin: + for line in fin: + lists.append(line.strip()) + return lists + + +def read_json_lists(list_file): + lists = read_lists(list_file) + results = {} + for fn in lists: + with open(fn, "r", encoding="utf8") as fin: + results.update(json.load(fin)) + return results + + +def load_wav(wav, target_sr, min_sr=16000): + speech, sample_rate = torchaudio.load(wav, backend="soundfile") + speech = speech.mean(dim=0, keepdim=True) + if sample_rate != target_sr: + assert ( + sample_rate >= min_sr + ), "wav sample rate {} must be greater than {}".format(sample_rate, target_sr) + speech = torchaudio.transforms.Resample( + orig_freq=sample_rate, new_freq=target_sr + )(speech) + return speech + + +def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): + import tensorrt as trt + + logging.info("Converting onnx to trt...") + network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + network = builder.create_network(network_flags) + parser = trt.OnnxParser(network, logger) + config = builder.create_builder_config() + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + profile = builder.create_optimization_profile() + with open(onnx_model, "rb") as f: + if not parser.parse(f.read()): + for error in range(parser.num_errors): + print(parser.get_error(error)) + raise ValueError("failed to parse {}".format(onnx_model)) + for i in range(len(trt_kwargs["input_names"])): + profile.set_shape( + trt_kwargs["input_names"][i], + trt_kwargs["min_shape"][i], + trt_kwargs["opt_shape"][i], + trt_kwargs["max_shape"][i], + ) + tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + input_tensor.dtype = tensor_dtype + for i in range(network.num_outputs): + output_tensor = network.get_output(i) + output_tensor.dtype = tensor_dtype + config.add_optimization_profile(profile) + engine_bytes = builder.build_serialized_network(network, config) + with open(trt_model, "wb") as f: + f.write(engine_bytes) + logging.info("Succesfully convert onnx to trt...") + + +def export_cosyvoice2_vllm(model, model_path, device): + if os.path.exists(model_path): + return + dtype = paddle.bfloat16 + use_bias = True if model.llm_decoder.bias is not None else False + model.llm.model.lm_head = model.llm_decoder + embed_tokens = model.llm.model.model.embed_tokens + model.llm.model.set_input_embeddings(model.speech_embedding) + model.llm.model.to(device) + model.llm.model.to(dtype) + tmp_vocab_size = model.llm.model.config.vocab_size + tmp_tie_embedding = model.llm.model.config.tie_word_embeddings + del model.llm.model.generation_config.eos_token_id + del model.llm.model.config.bos_token_id + del model.llm.model.config.eos_token_id + model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings + model.llm.model.config.tie_word_embeddings = False + model.llm.model.config.use_bias = use_bias + model.llm.model.save_pretrained(model_path) + if use_bias is True: + os.system( + "sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json".format( + os.path.abspath(model_path) + ) + ) + model.llm.model.config.vocab_size = tmp_vocab_size + model.llm.model.config.tie_word_embeddings = tmp_tie_embedding + model.llm.model.set_input_embeddings(embed_tokens) diff --git 
a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend.py new file mode 100644 index 000000000..6780f81be --- /dev/null +++ b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend.py @@ -0,0 +1,464 @@ +import json +import os +import re +from functools import partial +from typing import Callable, Generator +import librosa +import inflect +import numpy as np +import onnxruntime +import paddle +import paddle +import paddle.nn.functional as F +import numpy as np +from typing import Union, Optional +import paddlespeech +from .func import fbank +def _stft(x, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect"): + # x: [Time, Channel] + window = window.cpu().numpy() + x = x.cpu().numpy() + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? + # x: [Time, Channel, Freq] + x = np.stack( + [ + librosa.stft( + y=x[:, ch], + n_fft=n_fft, + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, ).T for ch in range(x.shape[1]) + ], + axis=1, ) + + if single_channel: + # x: [Time, Channel, Freq] -> [Time, Freq] + x = x[:, 0] + return x +def log_mel_spectrogram( + audio: Union[str, np.ndarray, paddle.Tensor], + n_mels: int = 80, + padding: int = 0, + device: Optional[str] = None, +): + """ + Compute the log-Mel spectrogram of audio using PaddlePaddle + + Parameters + ---------- + audio: Union[str, np.ndarray, paddle.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[str] + If given, the audio tensor is moved to this device (e.g., 'gpu:0') before STFT + + Returns + ------- + paddle.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + N_FFT = 400 + HOP_LENGTH = 160 + SAMPLE_RATE = 16000 + + if not paddle.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = paddle.to_tensor(audio.detach().cpu().numpy(), dtype='float32') + + if device is not None: + if 'gpu' in device: + place = paddle.CUDAPlace(int(device.split(':')[-1])) + else: + place = paddle.CPUPlace() + audio = audio.place(place) + if padding > 0: + audio = F.pad(audio.unsqueeze(0), [0, padding]).squeeze(0) + import torch + window = paddle.to_tensor(torch.hann_window(N_FFT).cpu().numpy()) + + # window = paddle.audio.functional.get_window( + # 'hann', + # N_FFT, + # dtype='float32' + # ).cuda() + # stft = _stft( + # audio, + # N_FFT, + # HOP_LENGTH, + # window=window + # ) + stft = paddle.signal.stft( + audio, + n_fft=N_FFT, + hop_length=HOP_LENGTH, + window=window + ) + # stft = paddle.to_tensor(stft) + magnitudes = stft[..., :-1].abs().square() + if magnitudes.shape[1] > N_FFT // 2: + magnitudes = magnitudes[:, :N_FFT // 2 + 1, :] + filters = mel_filters(audio.place, n_mels, N_FFT, SAMPLE_RATE) + mel_spec = paddle.matmul(filters, magnitudes) + log_spec = paddle.clip(mel_spec, min=1e-10).log10() + log_spec = paddle.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec.squeeze(0) + +def mel_filters(device: str, n_mels: int = 80, n_fft: int = 400, sr: int = 16000): + + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + 
filters_path = "/root/paddlejob/workspace/zhangjinghong/venv_cosy/lib/python3.10/site-packages/whisper/assets/mel_filters.npz" + with np.load(filters_path, allow_pickle=False) as f: + return paddle.to_tensor(f[f"mel_{n_mels}"]).to(device) + + + +try: + import ttsfrd + + use_ttsfrd = True +except ImportError: + print("failed to import ttsfrd, use wetext instead") + from wetext import Normalizer as EnNormalizer + from wetext import Normalizer as ZhNormalizer + + use_ttsfrd = False +from .file_utils import load_wav, logging +from .frontend_utils import (contains_chinese, + is_only_punctuation, + remove_bracket, replace_blank, + replace_corner_mark, + spell_out_number, split_paragraph) + + +class CosyVoiceFrontEnd: + def __init__( + self, + get_tokenizer: Callable, + feat_extractor: Callable, + campplus_model: str, + speech_tokenizer_model: str, + spk2info: str = "", + allowed_special: str = "all", + ): + self.tokenizer = get_tokenizer() + self.feat_extractor = feat_extractor + self.device = 'gpu:0' + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession( + campplus_model, sess_options=option, providers=["CPUExecutionProvider"] + ) + self.speech_tokenizer_session = onnxruntime.InferenceSession( + speech_tokenizer_model, + sess_options=option, + providers=[ + "CUDAExecutionProvider" + if paddle.device.is_compiled_with_cuda() + else "CPUExecutionProvider" + ], + ) + if os.path.exists(spk2info): + self.spk2info = paddle.load(path=str(spk2info)) + else: + self.spk2info = {} + self.allowed_special = allowed_special + self.use_ttsfrd = use_ttsfrd + if self.use_ttsfrd: + self.frd = ttsfrd.TtsFrontendEngine() + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + assert ( + self.frd.initialize( + "{}/../../pretrained_models/CosyVoice-ttsfrd/resource".format( + ROOT_DIR + ) + ) + is True + ), "failed to initialize ttsfrd resource" + self.frd.set_lang_type("pinyinvg") + else: + self.zh_tn_model = ZhNormalizer(remove_erhua=False) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() + + def _extract_text_token(self, text): + if isinstance(text, Generator): + logging.info( + "get tts_text generator, will return _extract_text_token_generator!" 
+ ) + return self._extract_text_token_generator(text), paddle.to_tensor( + [0], dtype=paddle.int32 + ).to(self.device) + else: + + text_token = self.tokenizer.encode( + text, allowed_special=self.allowed_special + ) + text_token = paddle.to_tensor([text_token], dtype=paddle.int32).to(self.device) + text_token_len = paddle.to_tensor( + [text_token.shape[1]], dtype=paddle.int32 + ).to(self.device) + return text_token, text_token_len + + def _extract_text_token_generator(self, text_generator): + for text in text_generator: + text_token, _ = self._extract_text_token(text) + for i in range(text_token.shape[1]): + yield text_token[:, i : i + 1] + + def _extract_speech_token(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) + + assert ( + speech.shape[1] / 16000 <= 30 + ), "do not support extract speech token for audio longer than 30s" + feat =log_mel_spectrogram(speech, n_mels=128) + feat = feat.unsqueeze(0) + speech_token = ( + self.speech_tokenizer_session.run( + None, + { + self.speech_tokenizer_session.get_inputs()[0] + .name: feat.detach() + .cpu() + .numpy(), + self.speech_tokenizer_session.get_inputs()[1].name: np.array( + [feat.shape[2]], dtype=np.int32 + ), + }, + )[0] + .flatten() + .tolist() + ) + speech_token = paddle.to_tensor([speech_token], dtype=paddle.int32).to(self.device) + speech_token_len = paddle.to_tensor( + [speech_token.shape[1]], dtype=paddle.int32 + ).to(self.device) + return speech_token, speech_token_len + + def _extract_spk_embedding(self, prompt_wav): + speech = load_wav(prompt_wav, 16000) + speech = paddle.to_tensor(speech.detach().cpu().numpy()).cuda() + feat = fbank( + speech, num_mel_bins=80, dither=0, sample_frequency=16000 + ) + feat = feat - feat.mean(axis=0, keepdim=True) + embedding = ( + self.campplus_session.run( + None, + { + self.campplus_session.get_inputs()[0] + .name: feat.unsqueeze(axis=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + embedding = paddle.to_tensor([embedding]).to(self.device) + return embedding + + def _extract_speech_feat(self, prompt_wav): + speech = load_wav(prompt_wav, 24000) + speech_feat = ( + paddle.transpose(self.feat_extractor(speech).squeeze(axis=0),perm=[0, 1]).to(self.device) + ) + + speech_feat = speech_feat.unsqueeze(axis=0) + speech_feat_len = paddle.to_tensor([speech_feat.shape[1]], dtype=paddle.int32).to( + self.device + ) + return speech_feat, speech_feat_len + + def text_normalize(self, text, split=True, text_frontend=True): + if isinstance(text, Generator): + logging.info("get tts_text generator, will skip text_normalize!") + return [text] + if "<|" in text and "|>" in text: + text_frontend = False + if text_frontend is False or text == "": + return [text] if split is True else text + text = text.strip() + if self.use_ttsfrd: + texts = [ + i["text"] + for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"] + ] + text = "".join(texts) + elif contains_chinese(text): + text = self.zh_tn_model.normalize(text) + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "。") + text = text.replace(" - ", ",") + text = remove_bracket(text) + text = re.sub("[,,、]+$", "。", text) + texts = list( + split_paragraph( + text, + partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + else: + text = self.en_tn_model.normalize(text) + text = spell_out_number(text, self.inflect_parser) + texts = list( + split_paragraph( + text, 
+ partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "en", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + texts = [i for i in texts if not is_only_punctuation(i)] + return texts if split is True else text + + def frontend_sft(self, tts_text, spk_id): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + embedding = self.spk2info[spk_id]["embedding"] + model_input = { + "text": tts_text_token, + "text_len": tts_text_token_len, + "llm_embedding": embedding, + "flow_embedding": embedding, + } + return model_input + + def frontend_zero_shot( + self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id + ): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + + if zero_shot_spk_id == "": + + prompt_text_token, prompt_text_token_len = self._extract_text_token( + prompt_text + ) + + speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav) + speech_feat=paddle.transpose(speech_feat,perm =[0,2,1]) + speech_token, speech_token_len = self._extract_speech_token(prompt_wav) + if resample_rate == 24000: + token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) + speech_feat, speech_feat_len[:] = ( + speech_feat[:, : 2 * token_len], + 2 * token_len, + ) + speech_token, speech_token_len[:] = ( + speech_token[:, :token_len], + token_len, + ) + + embedding = self._extract_spk_embedding(prompt_wav) + model_input = { + "prompt_text": prompt_text_token, + "prompt_text_len": prompt_text_token_len, + "llm_prompt_speech_token": speech_token, + "llm_prompt_speech_token_len": speech_token_len, + "flow_prompt_speech_token": speech_token, + "flow_prompt_speech_token_len": speech_token_len, + "prompt_speech_feat": speech_feat, + "prompt_speech_feat_len": speech_feat_len, + "llm_embedding": embedding, + "flow_embedding": embedding, + } + else: + model_input = self.spk2info[zero_shot_spk_id] + model_input["text"] = tts_text_token + model_input["text_len"] = tts_text_token_len + return model_input + + def frontend_cross_lingual( + self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id + ): + model_input = self.frontend_zero_shot( + tts_text, "", prompt_wav, resample_rate, zero_shot_spk_id + ) + del model_input["prompt_text"] + del model_input["prompt_text_len"] + del model_input["llm_prompt_speech_token"] + del model_input["llm_prompt_speech_token_len"] + return model_input + + def frontend_instruct(self, tts_text, spk_id, instruct_text): + model_input = self.frontend_sft(tts_text, spk_id) + del model_input["llm_embedding"] + instruct_text_token, instruct_text_token_len = self._extract_text_token( + instruct_text + ) + model_input["prompt_text"] = instruct_text_token + model_input["prompt_text_len"] = instruct_text_token_len + return model_input + + def frontend_instruct2( + self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id + ): + model_input = self.frontend_zero_shot( + tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id + ) + del model_input["llm_prompt_speech_token"] + del model_input["llm_prompt_speech_token_len"] + return model_input + + def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate): + prompt_speech_token, prompt_speech_token_len = self._extract_speech_token( + prompt_wav + ) + prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat( + prompt_wav + ) + embedding = self._extract_spk_embedding(prompt_wav) + source_speech_token, source_speech_token_len = self._extract_speech_token( + 
source_speech_16k + ) + model_input = { + "source_speech_token": source_speech_token, + "source_speech_token_len": source_speech_token_len, + "flow_prompt_speech_token": prompt_speech_token, + "flow_prompt_speech_token_len": prompt_speech_token_len, + "prompt_speech_feat": prompt_speech_feat, + "prompt_speech_feat_len": prompt_speech_feat_len, + "flow_embedding": embedding, + } + return model_input diff --git a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend_utils.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend_utils.py new file mode 100644 index 000000000..8ec6fd552 --- /dev/null +++ b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/frontend_utils.py @@ -0,0 +1,121 @@ +import re + +import regex + +chinese_char_pattern = re.compile("[\\u4e00-\\u9fff]+") + + +def contains_chinese(text): + return bool(chinese_char_pattern.search(text)) + + +def replace_corner_mark(text): + text = text.replace("²", "平方") + text = text.replace("³", "立方") + return text + + +def remove_bracket(text): + text = text.replace("(", "").replace(")", "") + text = text.replace("【", "").replace("】", "") + text = text.replace("`", "").replace("`", "") + text = text.replace("——", " ") + return text + + +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st:i]) + new_text.append(num_str) + st = None + new_text.append(c) + elif st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return "".join(new_text) + + +def split_paragraph( + text: str, + tokenize, + lang="zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, +): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"] + else: + pounc = [".", "?", "!", ";", ":"] + if comma_split: + pounc.extend([",", ","]) + if text[-1] not in pounc: + if lang == "zh": + text += "。" + else: + text += "." 
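+    # Cut the text at sentence-final punctuation (keeping the mark and any trailing
+    # closing quote attached to the finished piece), then greedily pack the pieces
+    # into utterances of roughly token_min_n..token_max_n length units, merging a
+    # leftover shorter than merge_len into the previous utterance.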
+ st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st:i]) > 0: + utts.append(text[st:i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', "”"]: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + final_utts = [] + cur_utt = "" + for utt in utts: + if ( + calc_utt_length(cur_utt + utt) > token_max_n + and calc_utt_length(cur_utt) > token_min_n + ): + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + return final_utts + + +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if (text[i + 1].isascii() and text[i + 1] != " ") and ( + text[i - 1].isascii() and text[i - 1] != " " + ): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) + + +def is_only_punctuation(text): + punctuation_pattern = "^[\\p{P}\\p{S}]*$" + return bool(regex.fullmatch(punctuation_pattern, text)) diff --git a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/func.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/func.py new file mode 100644 index 000000000..82a4ca6da --- /dev/null +++ b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/func.py @@ -0,0 +1,747 @@ +import math +from typing import Tuple + +import paddle + +import paddlespeech + +__all__ = [ + "get_mel_banks", + "inverse_mel_scale", + "inverse_mel_scale_scalar", + "mel_scale", + "mel_scale_scalar", + "spectrogram", + "fbank", + "mfcc", + "vtln_warp_freq", + "vtln_warp_mel_freq", +] +EPSILON = paddle.to_tensor(paddle.finfo(paddle.float32).eps) +MILLISECONDS_TO_SECONDS = 0.001 +HAMMING = "hamming" +HANNING = "hanning" +POVEY = "povey" +RECTANGULAR = "rectangular" +BLACKMAN = "blackman" +WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN] + + +def _get_epsilon(device, dtype): + return EPSILON.to(device=device, dtype=dtype) + + +def _next_power_of_2(x: int) -> int: + """Returns the smallest power of 2 that is greater than x""" + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +def _get_strided( + waveform: paddle.Tensor, window_size: int, window_shift: int, snip_edges: bool +) -> paddle.Tensor: + """Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) + representing how the window is shifted along the waveform. Each row is a frame. + + Args: + waveform (Tensor): Tensor of size ``num_samples`` + window_size (int): Frame length + window_shift (int): Frame shift + snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. 
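+            (With ``snip_edges=False`` this yields m = (num_samples + window_shift // 2) // window_shift frames, matching the computation below.)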
+ + Returns: + Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame + """ + assert waveform.dim() == 1 + num_samples = waveform.shape[0] + strides = window_shift * waveform.strides[0], waveform.strides[0] + if snip_edges: + if num_samples < window_size: + return paddle.empty((0, 0), dtype=waveform.dtype, device=waveform.device) + else: + m = 1 + (num_samples - window_size) // window_shift + else: + reversed_waveform = paddle.flip(x=waveform, axis=[0]) + m = (num_samples + window_shift // 2) // window_shift + pad = window_size // 2 - window_shift // 2 + pad_right = reversed_waveform + if pad > 0: + pad_left = reversed_waveform[-pad:] + waveform = paddle.cat((pad_left, waveform, pad_right), axis=0) + else: + waveform = paddle.cat((waveform[-pad:], pad_right), axis=0) + sizes = m, window_size + return waveform.as_strided(shape=sizes, stride=strides) + + +def _feature_window_function( + window_type: str, + window_size: int, + blackman_coeff: float, + device: paddle.device, + dtype: int, +) -> paddle.Tensor: + """Returns a window function with the given type and size""" + if window_type == HANNING: + return paddle.audio.functional.get_window( + win_length=window_size, fftbins=False, dtype=dtype, window="hann" + ) + elif window_type == HAMMING: + return paddle.hamming_window( + window_size, + periodic=False, + alpha=0.54, + beta=0.46, + device=device, + dtype=dtype, + ) + elif window_type == POVEY: + return paddle.audio.functional.get_window( + win_length=window_size, fftbins=False, dtype=dtype, window="hann" + ).pow(0.85) + elif window_type == RECTANGULAR: + return paddle.ones(window_size, device=device, dtype=dtype) + elif window_type == BLACKMAN: + a = 2 * math.pi / (window_size - 1) + window_function = paddle.arange(window_size, device=device, dtype=dtype) + return ( + blackman_coeff + - 0.5 * paddle.cos(a * window_function) + + (0.5 - blackman_coeff) * paddle.cos(2 * a * window_function) + ).to(device=device, dtype=dtype) + else: + raise Exception("Invalid window type " + window_type) + + +def _get_log_energy( + strided_input: paddle.Tensor, epsilon: paddle.Tensor, energy_floor: float +) -> paddle.Tensor: + """Returns the log energy of size (m) for a strided_input (m,*)""" + place, dtype = strided_input.place, strided_input.dtype + log_energy = paddle.maximum(strided_input.pow(2).sum(axis=1), epsilon).log() + if energy_floor == 0.0: + return log_energy + return paddle.maximum( + log_energy, paddle.to_tensor(math.log(energy_floor), place=place, dtype=dtype) + ) + + +def _get_waveform_and_window_properties( + waveform: paddle.Tensor, + channel: int, + sample_frequency: float, + frame_shift: float, + frame_length: float, + round_to_power_of_two: bool, + preemphasis_coefficient: float, +) -> Tuple[paddle.Tensor, int, int, int]: + """Gets the waveform and window properties""" + channel = max(channel, 0) + assert channel < waveform.shape[0], "Invalid channel {} for size {}".format( + channel, waveform.shape[0] + ) + waveform = waveform[channel, :] + window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS) + window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS) + padded_window_size = ( + _next_power_of_2(window_size) if round_to_power_of_two else window_size + ) + assert ( + 2 <= window_size <= len(waveform) + ), "choose a window size {} that is [2, {}]".format(window_size, len(waveform)) + assert 0 < window_shift, "`window_shift` must be greater than 0" + assert ( + padded_window_size % 2 == 0 + ), "the padded `window_size` must be 
divisible by two. use `round_to_power_of_two` or change `frame_length`" + assert ( + 0.0 <= preemphasis_coefficient <= 1.0 + ), "`preemphasis_coefficient` must be between [0,1]" + assert sample_frequency > 0, "`sample_frequency` must be greater than zero" + return waveform, window_shift, window_size, padded_window_size + + +def _get_window( + waveform: paddle.Tensor, + padded_window_size: int, + window_size: int, + window_shift: int, + window_type: str, + blackman_coeff: float, + snip_edges: bool, + raw_energy: bool, + energy_floor: float, + dither: float, + remove_dc_offset: bool, + preemphasis_coefficient: float, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Gets a window and its log energy + + Returns: + (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m) + """ + + place, dtype = waveform.place, waveform.dtype + epsilon = _get_epsilon(place, dtype) + strided_input = _get_strided(waveform, window_size, window_shift, snip_edges) + if dither != 0.0: + rand_gauss = paddle.randn(strided_input.shape, place=place, dtype=dtype) + strided_input = strided_input + rand_gauss * dither + if remove_dc_offset: + row_means = paddle.mean(strided_input, axis=1).unsqueeze(1) + strided_input = strided_input - row_means + if raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) + if preemphasis_coefficient != 0.0: + offset_strided_input = paddle.nn.functional.pad( + strided_input.unsqueeze(0), (1, 0), mode="replicate" + ).squeeze(0) + strided_input = ( + strided_input - preemphasis_coefficient * offset_strided_input[:, :-1] + ) + window_function = _feature_window_function( + window_type, window_size, blackman_coeff, place, dtype + ).unsqueeze(0) + strided_input = strided_input * window_function + if padded_window_size != window_size: + padding_right = padded_window_size - window_size + strided_input = paddle.nn.functional.pad( + strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0 + ).squeeze(0) + if not raw_energy: + signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) + return strided_input, signal_log_energy + + +def _subtract_column_mean(tensor: paddle.Tensor, subtract_mean: bool) -> paddle.Tensor: + if subtract_mean: + col_means = paddle.mean(tensor, axis=0).unsqueeze(0) + tensor = tensor - col_means + return tensor + + +def spectrogram( + waveform: paddle.Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_duration: float = 0.0, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + window_type: str = POVEY, +) -> paddle.Tensor: + """Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's + compute-spectrogram-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. 
to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``'povey'``) + + Returns: + Tensor: A spectrogram identical to what Kaldi would output. 
The shape is + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided + """ + device, dtype = waveform.device, waveform.dtype + epsilon = _get_epsilon(device, dtype) + ( + waveform, + window_shift, + window_size, + padded_window_size, + ) = _get_waveform_and_window_properties( + waveform, + channel, + sample_frequency, + frame_shift, + frame_length, + round_to_power_of_two, + preemphasis_coefficient, + ) + if len(waveform) < min_duration * sample_frequency: + return paddle.empty(0) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + fft = paddle.fft.rfft(strided_input) + power_spectrum = paddle.maximum(fft.abs().pow(2.0), epsilon).log() + power_spectrum[:, 0] = signal_log_energy + power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) + return power_spectrum + + +def inverse_mel_scale_scalar(mel_freq: float) -> float: + return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0) + + +def inverse_mel_scale(mel_freq: paddle.Tensor) -> paddle.Tensor: + return 700.0 * ((mel_freq / 1127.0).exp() - 1.0) + + +def mel_scale_scalar(freq: float) -> float: + return 1127.0 * math.log(1.0 + freq / 700.0) + + +def mel_scale(freq: paddle.Tensor) -> paddle.Tensor: + return 1127.0 * (1.0 + freq / 700.0).log() + + +def vtln_warp_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq: float, + high_freq: float, + vtln_warp_factor: float, + freq: paddle.Tensor, +) -> paddle.Tensor: + """This computes a VTLN warping function that is not the same as HTK's one, + but has similar inputs (this function has the advantage of never producing + empty bins). + + This function computes a warp function F(freq), defined between low_freq + and high_freq inclusive, with the following properties: + F(low_freq) == low_freq + F(high_freq) == high_freq + The function is continuous and piecewise linear with two inflection + points. + The lower inflection point (measured in terms of the unwarped + frequency) is at frequency l, determined as described below. + The higher inflection point is at a frequency h, determined as + described below. + If l <= f <= h, then F(f) = f/vtln_warp_factor. + If the higher inflection point (measured in terms of the unwarped + frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + Since (by the last point) F(h) == h/vtln_warp_factor, then + max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + = vtln_high_cutoff * min(1, vtln_warp_factor). 
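+        For example, with vtln_warp_factor = 1.25 and vtln_high_cutoff = 7600 Hz:
+        h = 7600 * min(1, 1.25) = 7600 and F(h) = 7600 / 1.25 = 6080, so
+        max(h, F(h)) = 7600 == vtln_high_cutoff, as required.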
+ If the lower inflection point (measured in terms of the unwarped + frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + = vtln_low_cutoff * max(1, vtln_warp_factor) + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + freq (Tensor): given frequency in Hz + + Returns: + Tensor: Freq after vtln warp + """ + assert ( + vtln_low_cutoff > low_freq + ), "be sure to set the vtln_low option higher than low_freq" + assert ( + vtln_high_cutoff < high_freq + ), "be sure to set the vtln_high option lower than high_freq [or negative]" + l = vtln_low_cutoff * max(1.0, vtln_warp_factor) + h = vtln_high_cutoff * min(1.0, vtln_warp_factor) + scale = 1.0 / vtln_warp_factor + Fl = scale * l + Fh = scale * h + assert l > low_freq and h < high_freq + scale_left = (Fl - low_freq) / (l - low_freq) + scale_right = (high_freq - Fh) / (high_freq - h) + res = paddle.empty_like(freq) + outside_low_high_freq = paddle.lt(freq, low_freq) | paddle.gt(freq, high_freq) + before_l = paddle.lt(freq, l) + before_h = paddle.lt(freq, h) + after_h = paddle.ge(freq, h) + res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq) + res[before_h] = scale * freq[before_h] + res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq) + res[outside_low_high_freq] = freq[outside_low_high_freq] + return res + + +def vtln_warp_mel_freq( + vtln_low_cutoff: float, + vtln_high_cutoff: float, + low_freq, + high_freq: float, + vtln_warp_factor: float, + mel_freq: paddle.Tensor, +) -> paddle.Tensor: + """ + Args: + vtln_low_cutoff (float): Lower frequency cutoffs for VTLN + vtln_high_cutoff (float): Upper frequency cutoffs for VTLN + low_freq (float): Lower frequency cutoffs in mel computation + high_freq (float): Upper frequency cutoffs in mel computation + vtln_warp_factor (float): Vtln warp factor + mel_freq (Tensor): Given frequency in Mel + + Returns: + Tensor: ``mel_freq`` after vtln warp + """ + return mel_scale( + vtln_warp_freq( + vtln_low_cutoff, + vtln_high_cutoff, + low_freq, + high_freq, + vtln_warp_factor, + inverse_mel_scale(mel_freq), + ) + ) + + +def get_mel_banks( + num_bins: int, + window_length_padded: int, + sample_freq: float, + low_freq: float, + high_freq: float, + vtln_low: float, + vtln_high: float, + vtln_warp_factor: float, +) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Returns: + (Tensor, Tensor): The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). + """ + assert num_bins > 3, "Must have at least 3 mel bins" + assert window_length_padded % 2 == 0 + num_fft_bins = window_length_padded / 2 + nyquist = 0.5 * sample_freq + if high_freq <= 0.0: + high_freq += nyquist + assert ( + 0.0 <= low_freq < nyquist + and 0.0 < high_freq <= nyquist + and low_freq < high_freq + ), "Bad values in options: low-freq {} and high-freq {} vs. 
nyquist {}".format( + low_freq, high_freq, nyquist + ) + fft_bin_width = sample_freq / window_length_padded + mel_low_freq = mel_scale_scalar(low_freq) + mel_high_freq = mel_scale_scalar(high_freq) + mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1) + if vtln_high < 0.0: + vtln_high += nyquist + assert ( + vtln_warp_factor == 1.0 + or low_freq < vtln_low < high_freq + and 0.0 < vtln_high < high_freq + and vtln_low < vtln_high + ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format( + vtln_low, vtln_high, low_freq, high_freq + ) + bin = paddle.arange(num_bins).unsqueeze(1) + left_mel = mel_low_freq + bin * mel_freq_delta + center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta + right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta + if vtln_warp_factor != 1.0: + left_mel = vtln_warp_mel_freq( + vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel + ) + center_mel = vtln_warp_mel_freq( + vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel + ) + right_mel = vtln_warp_mel_freq( + vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel + ) + center_freqs = inverse_mel_scale(center_mel) + mel = mel_scale(fft_bin_width * paddle.arange(num_fft_bins)).unsqueeze(0) + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + if vtln_warp_factor == 1.0: + bins = paddle.maximum( + paddle.zeros(1), paddle.minimum(up_slope, down_slope) + ) + else: + bins = paddle.zeros_like(up_slope) + up_idx = paddle.gt(mel, left_mel) & paddle.le(mel, center_mel) + down_idx = paddle.gt(mel, center_mel) & paddle.lt(mel, right_mel) + bins[up_idx] = up_slope[up_idx] + bins[down_idx] = down_slope[down_idx] + return bins, center_freqs + + +def fbank( + waveform: paddle.Tensor, + blackman_coeff: float = 0.42, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + use_log_fbank: bool = True, + use_power: bool = True, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> paddle.Tensor: + device, dtype = waveform.place, waveform.dtype + ( + waveform, + window_shift, + window_size, + padded_window_size, + ) = _get_waveform_and_window_properties( + waveform, + channel, + sample_frequency, + frame_shift, + frame_length, + round_to_power_of_two, + preemphasis_coefficient, + ) + if len(waveform) < min_duration * sample_frequency: + return paddle.empty(0, place=place, dtype=dtype) + strided_input, signal_log_energy = _get_window( + waveform, + padded_window_size, + window_size, + window_shift, + window_type, + blackman_coeff, + snip_edges, + raw_energy, + energy_floor, + dither, + remove_dc_offset, + preemphasis_coefficient, + ) + spectrum = paddle.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.0) + mel_energies, _ = get_mel_banks( + num_mel_bins, + padded_window_size, + sample_frequency, + low_freq, + high_freq, + vtln_low, + vtln_high, + vtln_warp, + ) + mel_energies = mel_energies.to(device=device, 
dtype=dtype) + mel_energies = paddle.nn.functional.pad( + mel_energies, (0, 1), mode="constant", value=0 + ) + mel_energies = paddle.mm(input=spectrum, mat2=mel_energies.T) + if use_log_fbank: + mel_energies = paddle.maximum( + mel_energies, _get_epsilon(device, dtype) + ).log() + if use_energy: + signal_log_energy = signal_log_energy.unsqueeze(1) + if htk_compat: + mel_energies = paddle.cat((mel_energies, signal_log_energy), axis=1) + else: + mel_energies = paddle.cat((signal_log_energy, mel_energies), axis=1) + mel_energies = _subtract_column_mean(mel_energies, subtract_mean) + return mel_energies + +def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> paddle.Tensor: + dct_matrix = paddle.audio.functional.create_dct( + n_mfcc=num_mel_bins, + n_mels=num_mel_bins, + norm='ortho' + ) + first_col = paddle.full( + shape=[num_mel_bins, 1], + fill_value=math.sqrt(1 / float(num_mel_bins)) + ) + if num_ceps > 1: + other_cols = dct_matrix[:, 1:num_ceps] + dct_matrix = paddle.concat([first_col, other_cols], axis=1) + else: + dct_matrix = first_col + return dct_matrix + + +def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> paddle.Tensor: + i = paddle.arange(num_ceps) + return 1.0 + 0.5 * cepstral_lifter * paddle.sin(math.pi * i / cepstral_lifter) + + +def mfcc( + waveform: paddle.Tensor, + blackman_coeff: float = 0.42, + cepstral_lifter: float = 22.0, + channel: int = -1, + dither: float = 0.0, + energy_floor: float = 1.0, + frame_length: float = 25.0, + frame_shift: float = 10.0, + high_freq: float = 0.0, + htk_compat: bool = False, + low_freq: float = 20.0, + num_ceps: int = 13, + min_duration: float = 0.0, + num_mel_bins: int = 23, + preemphasis_coefficient: float = 0.97, + raw_energy: bool = True, + remove_dc_offset: bool = True, + round_to_power_of_two: bool = True, + sample_frequency: float = 16000.0, + snip_edges: bool = True, + subtract_mean: bool = False, + use_energy: bool = False, + vtln_high: float = -500.0, + vtln_low: float = 100.0, + vtln_warp: float = 1.0, + window_type: str = POVEY, +) -> paddle.Tensor: + """Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's + compute-mfcc-feats. + + Args: + waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) + blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``) + channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) + dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``) + energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: + this floor is applied to the zeroth component, representing the total signal energy. The floor on the + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``1.0``) + frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) + (Default: ``0.0``) + htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible + features (need to change other parameters). 
(Default: ``False``) + low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``) + num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``) + min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``) + round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input + to FFT. (Default: ``True``) + sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if + specified there) (Default: ``16000.0``) + snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit + in the file, and the number of frames depends on the frame_length. If False, the number of frames + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) + subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do + it this way. (Default: ``False``) + use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') + (Default: ``"povey"``) + + Returns: + Tensor: A mfcc identical to what Kaldi would output. 
The shape is (m, ``num_ceps``) + where m is calculated in _get_strided + """ + assert ( + num_ceps <= num_mel_bins + ), "num_ceps cannot be larger than num_mel_bins: %d vs %d" % ( + num_ceps, + num_mel_bins, + ) + device, dtype = waveform.device, waveform.dtype + feature = fbank( + waveform=waveform, + blackman_coeff=blackman_coeff, + channel=channel, + dither=dither, + energy_floor=energy_floor, + frame_length=frame_length, + frame_shift=frame_shift, + high_freq=high_freq, + htk_compat=htk_compat, + low_freq=low_freq, + min_duration=min_duration, + num_mel_bins=num_mel_bins, + preemphasis_coefficient=preemphasis_coefficient, + raw_energy=raw_energy, + remove_dc_offset=remove_dc_offset, + round_to_power_of_two=round_to_power_of_two, + sample_frequency=sample_frequency, + snip_edges=snip_edges, + subtract_mean=False, + use_energy=use_energy, + use_log_fbank=True, + use_power=True, + vtln_high=vtln_high, + vtln_low=vtln_low, + vtln_warp=vtln_warp, + window_type=window_type, + ) + if use_energy: + signal_log_energy = feature[:, num_mel_bins if htk_compat else 0] + mel_offset = int(not htk_compat) + feature = feature[:, mel_offset : num_mel_bins + mel_offset] + dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device) + feature = feature.matmul(dct_matrix) + if cepstral_lifter != 0.0: + lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0) + feature *= lifter_coeffs.to(device=device, dtype=dtype) + if use_energy: + feature[:, 0] = signal_log_energy + if htk_compat: + energy = feature[:, 0].unsqueeze(1) + feature = feature[:, 1:] + if not use_energy: + energy *= math.sqrt(2) + feature = paddle.cat((feature, energy), axis=1) + feature = _subtract_column_mean(feature, subtract_mean) + return feature diff --git a/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/tokenizer.py b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/tokenizer.py new file mode 100644 index 000000000..779ede8d9 --- /dev/null +++ b/paddlespeech/t2s/frontend/CosyVoiceFrontEnd/tokenizer.py @@ -0,0 +1,578 @@ +import base64 +import os +from functools import lru_cache +from typing import Optional + +import paddle + +import tiktoken +from whisper.tokenizer import Tokenizer +from paddlenlp.transformers import AutoTokenizer +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": 
"somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", + "minnan": "minnan", + "wuyu": "wuyu", + "dialect": "dialect", + "zh/en": "zh/en", + "en/zh": "en/zh", +} +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", + "mandarin": "zh", +} +AUDIO_EVENT = { + "ASR": "ASR", + "AED": "AED", + "SER": "SER", + "Speech": "Speech", + "/Speech": "/Speech", + "BGM": "BGM", + "/BGM": "/BGM", + "Laughter": "Laughter", + "/Laughter": "/Laughter", + "Applause": "Applause", + "/Applause": "/Applause", +} +EMOTION = {"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL"} +TTS_Vocal_Token = { + "TTS/B": "TTS/B", + "TTS/O": "TTS/O", + "TTS/Q": "TTS/Q", + "TTS/A": "TTS/A", + "TTS/CO": "TTS/CO", + "TTS/CL": "TTS/CL", + "TTS/H": "TTS/H", + **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}, +} + + +@lru_cache(maxsize=None) +def get_encoding(name: str = "gpt2", num_languages: int = 99): + vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") + ranks = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in open(vocab_path) if line) + } + n_vocab = len(ranks) + special_tokens = {} + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], + *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], + *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], + *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + for token in specials: + special_tokens[token] = n_vocab + n_vocab += 1 + return tiktoken.Encoding( + name=os.path.basename(vocab_path), + explicit_n_vocab=n_vocab, + pat_str="'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + num_languages: int = 99, + language: Optional[str] = None, + task: Optional[str] = None, +) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise ValueError(f"Unsupported language: {language}") + if multilingual: + encoding_name = "multilingual_zh_ja_yue_char_del" + language = language or "en" + task = task or "transcribe" + else: + encoding_name = "gpt2" + language = None + task = None + encoding = get_encoding(name=encoding_name, num_languages=num_languages) + 
return Tokenizer( + encoding=encoding, num_languages=num_languages, language=language, task=task + ) + + +class CosyVoice2Tokenizer: + def __init__(self, token_path, skip_special_tokens=True): + super().__init__() + special_tokens = { + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|endofprompt|>", + "[breath]", + "", + "", + "[noise]", + "[laughter]", + "[cough]", + "[clucking]", + "[accent]", + "[quick_breath]", + "", + "", + "[hissing]", + "[sigh]", + "[vocalized-noise]", + "[lipsmack]", + "[mn]", + ], + } + self.special_tokens = special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def encode(self, text, **kwargs): + tokens = self.tokenizer(text) + tokens = tokens["input_ids"][0] + return tokens + + def decode(self, tokens): + tokens = paddle.tensor(tokens, dtype=paddle.int64) + text = self.tokenizer.batch_decode( + [tokens], skip_special_tokens=self.skip_special_tokens + )[0] + return text + + +class CosyVoice3Tokenizer(CosyVoice2Tokenizer): + def __init__(self, token_path, skip_special_tokens=True): + special_tokens = { + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|endofprompt|>", + "[breath]", + "", + "", + "[noise]", + "[laughter]", + "[cough]", + "[clucking]", + "[accent]", + "[quick_breath]", + "", + "", + "[hissing]", + "[sigh]", + "[vocalized-noise]", + "[lipsmack]", + "[mn]", + "<|endofsystem|>", + "[AA]", + "[AA0]", + "[AA1]", + "[AA2]", + "[AE]", + "[AE0]", + "[AE1]", + "[AE2]", + "[AH]", + "[AH0]", + "[AH1]", + "[AH2]", + "[AO]", + "[AO0]", + "[AO1]", + "[AO2]", + "[AW]", + "[AW0]", + "[AW1]", + "[AW2]", + "[AY]", + "[AY0]", + "[AY1]", + "[AY2]", + "[B]", + "[CH]", + "[D]", + "[DH]", + "[EH]", + "[EH0]", + "[EH1]", + "[EH2]", + "[ER]", + "[ER0]", + "[ER1]", + "[ER2]", + "[EY]", + "[EY0]", + "[EY1]", + "[EY2]", + "[F]", + "[G]", + "[HH]", + "[IH]", + "[IH0]", + "[IH1]", + "[IH2]", + "[IY]", + "[IY0]", + "[IY1]", + "[IY2]", + "[JH]", + "[K]", + "[L]", + "[M]", + "[N]", + "[NG]", + "[OW]", + "[OW0]", + "[OW1]", + "[OW2]", + "[OY]", + "[OY0]", + "[OY1]", + "[OY2]", + "[P]", + "[R]", + "[S]", + "[SH]", + "[T]", + "[TH]", + "[UH]", + "[UH0]", + "[UH1]", + "[UH2]", + "[UW]", + "[UW0]", + "[UW1]", + "[UW2]", + "[V]", + "[W]", + "[Y]", + "[Z]", + "[ZH]", + "[a]", + "[ai]", + "[an]", + "[ang]", + "[ao]", + "[b]", + "[c]", + "[ch]", + "[d]", + "[e]", + "[ei]", + "[en]", + "[eng]", + "[f]", + "[g]", + "[h]", + "[i]", + "[ian]", + "[in]", + "[ing]", + "[iu]", + "[ià]", + "[iàn]", + "[iàng]", + "[iào]", + "[iá]", + "[ián]", + "[iáng]", + "[iáo]", + "[iè]", + "[ié]", + "[iòng]", + "[ióng]", + "[iù]", + "[iú]", + "[iā]", + "[iān]", + "[iāng]", + "[iāo]", + "[iē]", + "[iě]", + "[iōng]", + "[iū]", + "[iǎ]", + "[iǎn]", + "[iǎng]", + "[iǎo]", + "[iǒng]", + "[iǔ]", + "[j]", + "[k]", + "[l]", + "[m]", + "[n]", + "[o]", + "[ong]", + "[ou]", + "[p]", + "[q]", + "[r]", + "[s]", + "[sh]", + "[t]", + "[u]", + "[uang]", + "[ue]", + "[un]", + "[uo]", + "[uà]", + "[uài]", + "[uàn]", + "[uàng]", + "[uá]", + "[uái]", + "[uán]", + "[uáng]", + "[uè]", + "[ué]", + "[uì]", + "[uí]", + "[uò]", + "[uó]", + "[uā]", + "[uāi]", + "[uān]", + "[uāng]", + "[uē]", + "[uě]", + "[uī]", + "[uō]", + "[uǎ]", + "[uǎi]", + "[uǎn]", + "[uǎng]", + "[uǐ]", + "[uǒ]", + "[vè]", + "[w]", + "[x]", + "[y]", + "[z]", + "[zh]", + "[à]", 
+ "[ài]", + "[àn]", + "[àng]", + "[ào]", + "[á]", + "[ái]", + "[án]", + "[áng]", + "[áo]", + "[è]", + "[èi]", + "[èn]", + "[èng]", + "[èr]", + "[é]", + "[éi]", + "[én]", + "[éng]", + "[ér]", + "[ì]", + "[ìn]", + "[ìng]", + "[í]", + "[ín]", + "[íng]", + "[ò]", + "[òng]", + "[òu]", + "[ó]", + "[óng]", + "[óu]", + "[ù]", + "[ùn]", + "[ú]", + "[ún]", + "[ā]", + "[āi]", + "[ān]", + "[āng]", + "[āo]", + "[ē]", + "[ēi]", + "[ēn]", + "[ēng]", + "[ě]", + "[ěi]", + "[ěn]", + "[ěng]", + "[ěr]", + "[ī]", + "[īn]", + "[īng]", + "[ō]", + "[ōng]", + "[ōu]", + "[ū]", + "[ūn]", + "[ǎ]", + "[ǎi]", + "[ǎn]", + "[ǎng]", + "[ǎo]", + "[ǐ]", + "[ǐn]", + "[ǐng]", + "[ǒ]", + "[ǒng]", + "[ǒu]", + "[ǔ]", + "[ǔn]", + "[ǘ]", + "[ǚ]", + "[ǜ]", + ], + } + self.special_tokens = special_tokens + self.tokenizer = transformers.AutoTokenizer.from_pretrained(token_path) + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + +@lru_cache(maxsize=None) +def get_qwen_tokenizer( + token_path: str, skip_special_tokens: bool, version: str = "cosyvoice2" +): + if version == "cosyvoice2": + return CosyVoice2Tokenizer( + token_path=token_path, skip_special_tokens=skip_special_tokens + ) + elif version == "cosyvoice3": + return CosyVoice3Tokenizer( + token_path=token_path, skip_special_tokens=skip_special_tokens + ) + else: + raise ValueError diff --git a/paddlespeech/t2s/models/CosyVoice/__init__.py b/paddlespeech/t2s/models/CosyVoice/__init__.py new file mode 100644 index 000000000..68ca8e7d8 --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .cosyvoice import * diff --git a/paddlespeech/t2s/models/CosyVoice/class_utils.py b/paddlespeech/t2s/models/CosyVoice/class_utils.py new file mode 100644 index 000000000..d2c9eb1de --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/class_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from paddlespeech.t2s.modules.transformer.activation import Swish +from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention +from paddlespeech.t2s.modules.transformer.embedding import EspnetRelPositionalEncoding +from paddlespeech.t2s.modules.transformer.subsampling import LinearNoSubsampling + + +COSYVOICE_ACTIVATION_CLASSES = { + "swish": Swish +} +COSYVOICE_SUBSAMPLE_CLASSES = { + "linear": LinearNoSubsampling, +} +COSYVOICE_EMB_CLASSES = { + "rel_pos_espnet": EspnetRelPositionalEncoding, +} +COSYVOICE_ATTENTION_CLASSES = { + "rel_selfattn": RelPositionMultiHeadedAttention, +} + diff --git a/paddlespeech/t2s/models/CosyVoice/common.py b/paddlespeech/t2s/models/CosyVoice/common.py new file mode 100644 index 000000000..15cda97cc --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/common.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import queue +import random +from typing import List + +import numpy as np + +def device2str(type=None, index=None, *, device=None): + type = device if device else type + if isinstance(type, int): + type = f'gpu:{type}' + elif isinstance(type, str): + if 'cuda' in type: + type = type.replace('cuda', 'gpu') + if 'cpu' in type: + type = 'cpu' + elif index is not None: + type = f'{type}:{index}' + elif isinstance(type, paddle.CPUPlace) or (type is None): + type = 'cpu' + elif isinstance(type, paddle.CUDAPlace): + type = f'gpu:{type.get_device_id()}' + + return type + + +IGNORE_ID = -1 + + +def pad_list(xs: List[paddle.Tensor], pad_value: int): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + max_len = max([len(item) for item in xs]) + batchs = len(xs) + ndim = xs[0].ndim + if ndim == 1: + pad_res = paddle.zeros(batchs, max_len, dtype=xs[0].dtype, device=xs[0].place) + elif ndim == 2: + pad_res = paddle.zeros( + batchs, max_len, xs[0].shape[1], dtype=xs[0].dtype, device=xs[0].place + ) + elif ndim == 3: + pad_res = paddle.zeros( + batchs, + max_len, + xs[0].shape[1], + xs[0].shape[2], + dtype=xs[0].dtype, + device=xs[0].place, + ) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + pad_res.fill_(pad_value) + for i in range(batchs): + pad_res[i, : len(xs[i])] = xs[i] + return pad_res + + +def th_accuracy( + pad_outputs: paddle.Tensor, pad_targets: paddle.Tensor, ignore_label: int +) -> paddle.Tensor: + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax). + ignore_label (int): Ignore label id. 
+ + Returns: + torch.Tensor: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) + ).argmax(2) + mask = pad_targets != ignore_label + numerator = paddle.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = paddle.sum(mask) + return (numerator / denominator).detach() + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def ras_sampling( + weighted_scores, + decoded_tokens, + sampling, + top_p=0.8, + top_k=25, + win_size=10, + tau_r=0.1, +): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + rep_num = ( + (paddle.to_tensor(decoded_tokens[-win_size:],dtype = paddle.long).to(weighted_scores.place) == top_ids) + .sum() + .item() + ) + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)[0] + return top_ids + + +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + prob, indices = [], [] + cum_prob = 0.0 + sorted_value, sorted_idx = paddle.sort( + descending=True, stable=True, x=weighted_scores.softmax(axis=0) + ), paddle.argsort(descending=True, stable=True, x=weighted_scores.softmax(axis=0)) + + for i in range(len(sorted_idx)): + if cum_prob < top_p and len(prob) < top_k: + cum_prob += sorted_value[i] + prob.append(sorted_value[i]) + indices.append(sorted_idx[i]) + else: + break + prob = paddle.to_tensor(prob).cuda() + indices = paddle.to_tensor(indices, dtype=paddle.long).to(weighted_scores.place) + # top_ids = indices[prob.multinomial(num_samples=1, replacement=True)] + top_ids = indices[0] + return top_ids + + +def random_sampling(weighted_scores, decoded_tokens, sampling): + top_ids = weighted_scores.softmax(axis=0).multinomial( + num_samples=1, replacement=True + ) + return top_ids + + +def fade_in_out(fade_in_mel, fade_out_mel, window): + device = fade_in_mel.place + fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() + mel_overlap_len = int(window.shape[0] / 2) + if fade_in_mel.place == device2str("cpu"): + fade_in_mel = fade_in_mel.clone() + fade_in_mel[..., :mel_overlap_len] = ( + fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] + ) + return fade_in_mel.to(device) + + +def set_all_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + paddle.seed(seed) + + +def mask_to_bias(mask: paddle.Tensor, dtype: paddle.dtype) -> paddle.Tensor: + assert mask.dtype == paddle.bool + assert dtype in [paddle.float32, paddle.bfloat16, paddle.float16] + mask = mask.to(dtype) + mask = (1.0 - mask) * -10000000000.0 + return mask + + +class TrtContextWrapper: + def __init__(self, trt_engine, trt_concurrent=1, device="cuda:0"): + self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) + self.trt_engine = trt_engine + for _ in range(trt_concurrent): + trt_context = trt_engine.create_execution_context() + trt_stream = paddle.device.stream_guard( + paddle.device.Stream(device=device2str(device)) + ) + assert ( + trt_context is not None + ), "failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}".format( + trt_concurrent + ) + self.trt_context_pool.put([trt_context, trt_stream]) + assert self.trt_context_pool.empty() is False, "no avaialbe 
estimator context" + + def acquire_estimator(self): + return self.trt_context_pool.get(), self.trt_engine + + def release_estimator(self, context, stream): + self.trt_context_pool.put([context, stream]) \ No newline at end of file diff --git a/paddlespeech/t2s/models/CosyVoice/cosyvoice.py b/paddlespeech/t2s/models/CosyVoice/cosyvoice.py new file mode 100644 index 000000000..d463e0096 --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/cosyvoice.py @@ -0,0 +1,374 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +from typing import Generator + +import paddle +from hyperpyyaml import load_hyperpyyaml +from modelscope import snapshot_download +import logging +logging.getLogger('matplotlib').setLevel(logging.WARNING) +logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(message)s') +from paddlespeech.t2s.models.CosyVoice.frontend import CosyVoiceFrontEnd +from paddlespeech.t2s.models.CosyVoice.model import CosyVoice2Model + +def get_model_type(configs): + # NOTE CosyVoice2Model inherits CosyVoiceModel + if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): + return CosyVoiceModel + if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): + return CosyVoice2Model + raise TypeError('No valid model type found!') +class CosyVoice: + def __init__( + self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1 + ): + self.instruct = True if "-Instruct" in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 + if not os.path.exists(model_dir): + model_dir = snapshot_download(model_dir) + hyper_yaml_path = "{}/cosyvoice.yaml".format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError("{} not found!".format(hyper_yaml_path)) + with open(hyper_yaml_path, "r") as f: + configs = load_hyperpyyaml(f) + # assert ( + # get_model_type(configs) != CosyVoice2Model + # ), "do not use {} for CosyVoice initialization!".format(model_dir) + self.frontend = CosyVoiceFrontEnd( + configs["get_tokenizer"], + configs["feat_extractor"], + "{}/campplus.onnx".format(model_dir), + "{}/speech_tokenizer_v1.onnx".format(model_dir), + "{}/spk2info.pt".format(model_dir), + configs["allowed_special"], + ) + self.sample_rate = configs["sample_rate"] + if (paddle.device.cuda.device_count() >= 1) is False and ( + load_jit is True or load_trt is True or fp16 is True + ): + load_jit, load_trt, fp16 = False, False, False + logging.warning("no cuda device, set load_jit/load_trt/fp16 to False") + self.model = CosyVoiceModel( + configs["llm"], configs["flow"], configs["hift"], fp16 + ) + self.model.load( + "{}/llm.pt".format(model_dir), + "{}/flow.pt".format(model_dir), + "{}/hift.pt".format(model_dir), + ) + if load_jit: + self.model.load_jit( + "{}/llm.text_encoder.{}.zip".format( + model_dir, 
"fp16" if self.fp16 is True else "fp32" + ), + "{}/llm.llm.{}.zip".format( + model_dir, "fp16" if self.fp16 is True else "fp32" + ), + "{}/flow.encoder.{}.zip".format( + model_dir, "fp16" if self.fp16 is True else "fp32" + ), + ) + if load_trt: + self.model.load_trt( + "{}/flow.decoder.estimator.{}.mygpu.plan".format( + model_dir, "fp16" if self.fp16 is True else "fp32" + ), + "{}/flow.decoder.estimator.fp32.onnx".format(model_dir), + trt_concurrent, + self.fp16, + ) + del configs + + def list_available_spks(self): + spks = list(self.frontend.spk2info.keys()) + return spks + + def add_zero_shot_spk(self, prompt_text, prompt_speech_16k, zero_shot_spk_id): + assert zero_shot_spk_id != "", "do not use empty zero_shot_spk_id" + model_input = self.frontend.frontend_zero_shot( + "", prompt_text, prompt_speech_16k, self.sample_rate, "" + ) + del model_input["text"] + del model_input["text_len"] + self.frontend.spk2info[zero_shot_spk_id] = model_input + return True + + def save_spkinfo(self): + paddle.save( + obj=self.frontend.spk2info, path="{}/spk2info.pt".format(self.model_dir) + ) + + def inference_sft( + self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True + ): + for i in tqdm( + self.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ) + ): + model_input = self.frontend.frontend_sft(i, spk_id) + start_time = time.time() + logging.info("synthesis text {}".format(i)) + for model_output in self.model.tts( + **model_input, stream=stream, speed=speed + ): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, (time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() + + def inference_zero_shot( + self, + tts_text, + prompt_text, + prompt_speech_16k, + zero_shot_spk_id="", + stream=False, + speed=1.0, + text_frontend=True, + ): + prompt_text = self.frontend.text_normalize( + prompt_text, split=False, text_frontend=text_frontend + ) + for i in tqdm( + self.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ) + ): + if not isinstance(i, Generator) and len(i) < 0.5 * len(prompt_text): + logging.warning( + "synthesis text {} too short than prompt text {}, this may lead to bad performance".format( + i, prompt_text + ) + ) + model_input = self.frontend.frontend_zero_shot( + i, prompt_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id + ) + start_time = time.time() + logging.info("synthesis text {}".format(i)) + for model_output in self.model.tts( + **model_input, stream=stream, speed=speed + ): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, (time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() + + def inference_cross_lingual( + self, + tts_text, + prompt_speech_16k, + zero_shot_spk_id="", + stream=False, + speed=1.0, + text_frontend=True, + ): + for i in tqdm( + self.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ) + ): + model_input = self.frontend.frontend_cross_lingual( + i, prompt_speech_16k, self.sample_rate, zero_shot_spk_id + ) + start_time = time.time() + logging.info("synthesis text {}".format(i)) + for model_output in self.model.tts( + **model_input, stream=stream, speed=speed + ): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, 
(time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() + + def inference_instruct( + self, + tts_text, + spk_id, + instruct_text, + stream=False, + speed=1.0, + text_frontend=True, + ): + assert isinstance( + self.model, CosyVoiceModel + ), "inference_instruct is only implemented for CosyVoice!" + if self.instruct is False: + raise ValueError( + "{} do not support instruct inference".format(self.model_dir) + ) + instruct_text = self.frontend.text_normalize( + instruct_text, split=False, text_frontend=text_frontend + ) + for i in tqdm( + self.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ) + ): + model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) + start_time = time.time() + logging.info("synthesis text {}".format(i)) + for model_output in self.model.tts( + **model_input, stream=stream, speed=speed + ): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, (time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() + + def inference_vc( + self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0 + ): + model_input = self.frontend.frontend_vc( + source_speech_16k, prompt_speech_16k, self.sample_rate + ) + start_time = time.time() + for model_output in self.model.tts(**model_input, stream=stream, speed=speed): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, (time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() + + +class CosyVoice2(CosyVoice): + def __init__( + self, + model_dir, + load_jit=False, + load_trt=False, + load_vllm=False, + fp16=False, + trt_concurrent=1, + ): + self.instruct = True if "-Instruct" in model_dir else False + self.model_dir = model_dir + self.fp16 = fp16 + hyper_yaml_path = "{}/cosyvoice2.yaml".format(model_dir) + if not os.path.exists(hyper_yaml_path): + raise ValueError("{} not found!".format(hyper_yaml_path)) + with open(hyper_yaml_path, "r") as f: + configs = load_hyperpyyaml( + f, + overrides={ + "qwen_pretrain_path": os.path.join(model_dir, "CosyVoice-BlankEN") + }, + ) + # assert ( + # get_model_type(configs) == CosyVoice2Model + # ), "do not use {} for CosyVoice2 initialization!".format(model_dir) + self.frontend = CosyVoiceFrontEnd( + configs["get_tokenizer"], + configs["feat_extractor"], + "{}/campplus.onnx".format(model_dir), + "{}/speech_tokenizer_v2.onnx".format(model_dir), + "{}/spk2info.pt".format(model_dir), + configs["allowed_special"], + ) + self.sample_rate = configs["sample_rate"] + if (paddle.device.cuda.device_count() >= 1) is False and ( + load_jit is True or load_trt is True or fp16 is True + ): + load_jit, load_trt, fp16 = False, False, False + logging.warning("no cuda device, set load_jit/load_trt/fp16 to False") + self.model = CosyVoice2Model( + configs["llm"], configs["flow"], configs["hift"], fp16 + ) + self.model.load( + "{}/llm.pt".format(model_dir), + "{}/flow.pt".format(model_dir), + "{}/hift.pt".format(model_dir), + ) + if load_vllm: + self.model.load_vllm("{}/vllm".format(model_dir)) + if load_jit: + self.model.load_jit( + "{}/flow.encoder.{}.zip".format( + model_dir, "fp16" if self.fp16 is True else "fp32" + ) + ) + if load_trt: + self.model.load_trt( + "{}/flow.decoder.estimator.{}.mygpu.plan".format( + model_dir, "fp16" if self.fp16 is True else "fp32" + ), + 
"{}/flow.decoder.estimator.fp32.onnx".format(model_dir), + trt_concurrent, + self.fp16, + ) + del configs + + def inference_instruct(self, *args, **kwargs): + raise NotImplementedError( + "inference_instruct is not implemented for CosyVoice2!" + ) + + def inference_instruct2( + self, + tts_text, + instruct_text, + prompt_speech_16k, + zero_shot_spk_id="", + stream=False, + speed=1.0, + text_frontend=True, + ): + assert isinstance( + self.model, CosyVoice2Model + ), "inference_instruct2 is only implemented for CosyVoice2!" + for i in tqdm( + self.frontend.text_normalize( + tts_text, split=True, text_frontend=text_frontend + ) + ): + model_input = self.frontend.frontend_instruct2( + i, instruct_text, prompt_speech_16k, self.sample_rate, zero_shot_spk_id + ) + start_time = time.time() + logging.info("synthesis text {}".format(i)) + for model_output in self.model.tts( + **model_input, stream=stream, speed=speed + ): + speech_len = model_output["tts_speech"].shape[1] / self.sample_rate + logging.info( + "yield speech len {}, rtf {}".format( + speech_len, (time.time() - start_time) / speech_len + ) + ) + yield model_output + start_time = time.time() diff --git a/paddlespeech/t2s/models/CosyVoice/flow.py b/paddlespeech/t2s/models/CosyVoice/flow.py new file mode 100644 index 000000000..c5529a5ac --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/flow.py @@ -0,0 +1,267 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math +from typing import Any +from typing import Dict +from typing import List + +import paddle +from paddle import nn +from paddle.nn import functional as F + +class Decoder(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + + self.down_blocks = nn.ModuleList([]) + self.mid_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + output_channel = in_channels + for i in range(len(channels)): # pylint: disable=consider-using-enumerate + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) + + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.ModuleList( + [ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.ModuleList( + [ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + + self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) + + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) + + self.initialize_weights() + # nn.init.normal_(self.final_proj.weight) + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + 
activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + + return block + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. + + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + + t = self.time_embeddings(t) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c") + mask_down = rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_down, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_down = rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c") + mask_mid = rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_mid, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_mid = rearrange(mask_mid, "b t -> b 1 t") + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = rearrange(x, "b c t -> b t c") + mask_up = rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=mask_up, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t") + mask_up = rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + + return output * mask diff --git a/paddlespeech/t2s/models/CosyVoice/frontend.py b/paddlespeech/t2s/models/CosyVoice/frontend.py new file mode 100644 index 000000000..6aca2489e --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/frontend.py @@ -0,0 +1,459 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import os +import re +from functools import partial +from typing import Callable, Generator + +import inflect +import numpy as np +import onnxruntime +import paddle +import paddlespeech +import whisper +import logging +try: + import ttsfrd + + use_ttsfrd = True +except ImportError: + print("failed to import ttsfrd, use wetext instead") + from wetext import Normalizer as EnNormalizer + from wetext import Normalizer as ZhNormalizer + + use_ttsfrd = False +# split paragrah logic: +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. split sentence according to puncatation +def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] + else: + pounc = ['.', '?', '!', ';', ':'] + if comma_split: + pounc.extend([',', ',']) + + if text[-1] not in pounc: + if lang == "zh": + text += "。" + else: + text += "." + + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st: i]) > 0: + utts.append(text[st: i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', '”']: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st: i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return ''.join(new_text) + +# replace special symbol +def replace_corner_mark(text): + text = text.replace('²', '平方') + text = text.replace('³', '立方') + return text + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if ((text[i + 1].isascii() and text[i + 1] != " ") and + (text[i - 1].isascii() and text[i - 1] != " ")): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) +def is_only_punctuation(text): + # Regular expression: Match strings that consist only of punctuation marks or are empty. 
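+    # NOTE: the \p{P}/\p{S} property classes below require the third-party
+    # `regex` package (add `import regex` alongside the other module imports);
+    # the standard-library `re` module does not support them.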
+ punctuation_pattern = r'^[\p{P}\p{S}]*$' + return bool(regex.fullmatch(punctuation_pattern, text)) + +# remove meaningless symbol +def remove_bracket(text): + text = text.replace('(', '').replace(')', '') + text = text.replace('【', '').replace('】', '') + text = text.replace('`', '').replace('`', '') + text = text.replace("——", " ") + return text +class CosyVoiceFrontEnd: + def __init__( + self, + get_tokenizer: Callable, + feat_extractor: Callable, + campplus_model: str, + speech_tokenizer_model: str, + spk2info: str = "", + allowed_special: str = "all", + ): + self.tokenizer = get_tokenizer() + self.feat_extractor = feat_extractor + self.device = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + option = onnxruntime.SessionOptions() + option.graph_optimization_level = ( + onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + ) + option.intra_op_num_threads = 1 + self.campplus_session = onnxruntime.InferenceSession( + campplus_model, sess_options=option, providers=["CPUExecutionProvider"] + ) + self.speech_tokenizer_session = onnxruntime.InferenceSession( + speech_tokenizer_model, + sess_options=option, + providers=[ + "CUDAExecutionProvider" + if paddle.device.cuda.device_count() >= 1 + else "CPUExecutionProvider" + ], + ) + if os.path.exists(spk2info): + self.spk2info = paddle.load(path=str(spk2info)) + else: + self.spk2info = {} + self.allowed_special = allowed_special + self.use_ttsfrd = use_ttsfrd + if self.use_ttsfrd: + self.frd = ttsfrd.TtsFrontendEngine() + ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + assert ( + self.frd.initialize( + "{}/../../pretrained_models/CosyVoice-ttsfrd/resource".format( + ROOT_DIR + ) + ) + is True + ), "failed to initialize ttsfrd resource" + self.frd.set_lang_type("pinyinvg") + else: + self.zh_tn_model = ZhNormalizer(remove_erhua=False) + self.en_tn_model = EnNormalizer() + self.inflect_parser = inflect.engine() + + def _extract_text_token(self, text): + if isinstance(text, Generator): + logging.info( + "get tts_text generator, will return _extract_text_token_generator!" 
+ ) + return self._extract_text_token_generator(text), paddle.tensor( + [0], dtype=paddle.int32 + ).to(self.device) + else: + text_token = self.tokenizer.encode( + text, allowed_special=self.allowed_special + ) + text_token = paddle.tensor([text_token], dtype=paddle.int32).to(self.device) + text_token_len = paddle.tensor( + [text_token.shape[1]], dtype=paddle.int32 + ).to(self.device) + return text_token, text_token_len + + def _extract_text_token_generator(self, text_generator): + for text in text_generator: + text_token, _ = self._extract_text_token(text) + for i in range(text_token.shape[1]): + yield text_token[:, i : i + 1] + + def _extract_speech_token(self, speech): + assert ( + speech.shape[1] / 16000 <= 30 + ), "do not support extract speech token for audio longer than 30s" + feat = whisper.log_mel_spectrogram(speech, n_mels=128) + speech_token = ( + self.speech_tokenizer_session.run( + None, + { + self.speech_tokenizer_session.get_inputs()[0] + .name: feat.detach() + .cpu() + .numpy(), + self.speech_tokenizer_session.get_inputs()[1].name: np.array( + [feat.shape[2]], dtype=np.int32 + ), + }, + )[0] + .flatten() + .tolist() + ) + speech_token = paddle.tensor([speech_token], dtype=paddle.int32).to(self.device) + speech_token_len = paddle.tensor( + [speech_token.shape[1]], dtype=paddle.int32 + ).to(self.device) + return speech_token, speech_token_len + + def _extract_spk_embedding(self, speech): + ##################>>>>>>>>>>>>>>>>>>> + feat = torchaudio.compliance.kaldi.fbank( + speech, num_mel_bins=80, dither=0, sample_frequency=16000 + ) + ##################>>>>>>>>>>>>>>>>>>> + feat = feat - feat.mean(dim=0, keepdim=True) + embedding = ( + self.campplus_session.run( + None, + { + self.campplus_session.get_inputs()[0] + .name: feat.unsqueeze(dim=0) + .cpu() + .numpy() + }, + )[0] + .flatten() + .tolist() + ) + embedding = paddle.tensor([embedding]).to(self.device) + return embedding + + def _extract_speech_feat(self, speech): + speech_feat = ( + self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device) + ) + speech_feat = speech_feat.unsqueeze(dim=0) + speech_feat_len = paddle.tensor([speech_feat.shape[1]], dtype=paddle.int32).to( + self.device + ) + return speech_feat, speech_feat_len + + def text_normalize(self, text, split=True, text_frontend=True): + if isinstance(text, Generator): + logging.info("get tts_text generator, will skip text_normalize!") + return [text] + if text_frontend is False or text == "": + return [text] if split is True else text + text = text.strip() + if self.use_ttsfrd: + texts = [ + i["text"] + for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"] + ] + text = "".join(texts) + elif contains_chinese(text): + text = self.zh_tn_model.normalize(text) + text = text.replace("\n", "") + text = replace_blank(text) + text = replace_corner_mark(text) + text = text.replace(".", "。") + text = text.replace(" - ", ",") + text = remove_bracket(text) + text = re.sub("[,,、]+$", "。", text) + texts = list( + split_paragraph( + text, + partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + else: + text = self.en_tn_model.normalize(text) + text = spell_out_number(text, self.inflect_parser) + texts = list( + split_paragraph( + text, + partial( + self.tokenizer.encode, allowed_special=self.allowed_special + ), + "en", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, + ) + ) + texts = [i for i in texts if not 
is_only_punctuation(i)] + return texts if split is True else text + + def frontend_sft(self, tts_text, spk_id): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + embedding = self.spk2info[spk_id]["embedding"] + model_input = { + "text": tts_text_token, + "text_len": tts_text_token_len, + "llm_embedding": embedding, + "flow_embedding": embedding, + } + return model_input + + def frontend_zero_shot( + self, tts_text, prompt_text, prompt_speech_16k, resample_rate, zero_shot_spk_id + ): + tts_text_token, tts_text_token_len = self._extract_text_token(tts_text) + if zero_shot_spk_id == "": + prompt_text_token, prompt_text_token_len = self._extract_text_token( + prompt_text + ) + #>>>>>>>>>>>>>>>>>>> + prompt_speech_resample = torchaudio.transforms.Resample( + orig_freq=16000, new_freq=resample_rate + )(prompt_speech_16k) + #>>>>>>>>>>>>>>>>>>> + speech_feat, speech_feat_len = self._extract_speech_feat( + prompt_speech_resample + ) + speech_token, speech_token_len = self._extract_speech_token( + prompt_speech_16k + ) + if resample_rate == 24000: + token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1]) + speech_feat, speech_feat_len[:] = ( + speech_feat[:, : 2 * token_len], + 2 * token_len, + ) + speech_token, speech_token_len[:] = ( + speech_token[:, :token_len], + token_len, + ) + embedding = self._extract_spk_embedding(prompt_speech_16k) + model_input = { + "prompt_text": prompt_text_token, + "prompt_text_len": prompt_text_token_len, + "llm_prompt_speech_token": speech_token, + "llm_prompt_speech_token_len": speech_token_len, + "flow_prompt_speech_token": speech_token, + "flow_prompt_speech_token_len": speech_token_len, + "prompt_speech_feat": speech_feat, + "prompt_speech_feat_len": speech_feat_len, + "llm_embedding": embedding, + "flow_embedding": embedding, + } + else: + model_input = self.spk2info[zero_shot_spk_id] + model_input["text"] = tts_text_token + model_input["text_len"] = tts_text_token_len + return model_input + + def frontend_cross_lingual( + self, tts_text, prompt_speech_16k, resample_rate, zero_shot_spk_id + ): + model_input = self.frontend_zero_shot( + tts_text, "", prompt_speech_16k, resample_rate, zero_shot_spk_id + ) + del model_input["prompt_text"] + del model_input["prompt_text_len"] + del model_input["llm_prompt_speech_token"] + del model_input["llm_prompt_speech_token_len"] + return model_input + + def frontend_instruct(self, tts_text, spk_id, instruct_text): + model_input = self.frontend_sft(tts_text, spk_id) + del model_input["llm_embedding"] + instruct_text_token, instruct_text_token_len = self._extract_text_token( + instruct_text + "" + ) + model_input["prompt_text"] = instruct_text_token + model_input["prompt_text_len"] = instruct_text_token_len + return model_input + + def frontend_instruct2( + self, + tts_text, + instruct_text, + prompt_speech_16k, + resample_rate, + zero_shot_spk_id, + ): + model_input = self.frontend_zero_shot( + tts_text, + instruct_text + "<|endofprompt|>", + prompt_speech_16k, + resample_rate, + zero_shot_spk_id, + ) + del model_input["llm_prompt_speech_token"] + del model_input["llm_prompt_speech_token_len"] + return model_input + + def frontend_vc(self, source_speech_16k, prompt_speech_16k, resample_rate): + prompt_speech_token, prompt_speech_token_len = self._extract_speech_token( + prompt_speech_16k + ) + #>>>>>>>>>>>>>>>>>> + prompt_speech_resample = torchaudio.transforms.Resample( + orig_freq=16000, new_freq=resample_rate + )(prompt_speech_16k) + #>>>>>>>>>>>>>>>>>> + prompt_speech_feat, 
prompt_speech_feat_len = self._extract_speech_feat( + prompt_speech_resample + ) + embedding = self._extract_spk_embedding(prompt_speech_16k) + source_speech_token, source_speech_token_len = self._extract_speech_token( + source_speech_16k + ) + model_input = { + "source_speech_token": source_speech_token, + "source_speech_token_len": source_speech_token_len, + "flow_prompt_speech_token": prompt_speech_token, + "flow_prompt_speech_token_len": prompt_speech_token_len, + "prompt_speech_feat": prompt_speech_feat, + "prompt_speech_feat_len": prompt_speech_feat_len, + "flow_embedding": embedding, + } + return model_input \ No newline at end of file diff --git a/paddlespeech/t2s/models/CosyVoice/llm.py b/paddlespeech/t2s/models/CosyVoice/llm.py new file mode 100644 index 000000000..ae35ff0af --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/llm.py @@ -0,0 +1,735 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import random +import threading +import time +from typing import Callable, Dict, Generator, List, Optional +import logging +import paddle.nn.functional as F +import paddle +IGNORE_ID = -1 +import torch +LabelSmoothingLoss = None +def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): + top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) + recent_tokens = paddle.to_tensor(decoded_tokens[-win_size:], dtype='int64') + rep_num = paddle.sum(recent_tokens.cpu() == top_ids.cpu()).cpu().item() + if rep_num >= win_size * tau_r: + top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) + return top_ids + + +def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): + softmax_scores = paddle.nn.functional.softmax(weighted_scores, axis=0) + sorted_indices = paddle.argsort(softmax_scores, axis=0, descending=True) + sorted_probs = paddle.gather(softmax_scores, sorted_indices, axis=0) + + prob_list = [] + indices_list = [] + cum_prob = 0.0 + + for i in range(len(sorted_indices)): + if cum_prob < top_p and len(prob_list) < top_k: + cum_prob += sorted_probs[i].item() + prob_list.append(sorted_probs[i]) + indices_list.append(sorted_indices[i]) + else: + break + + prob_tensor = paddle.to_tensor(prob_list, dtype=weighted_scores.dtype) + indices_tensor = paddle.to_tensor(indices_list, dtype='int64') + top_ids = indices_tensor[paddle.multinomial(prob_tensor, num_samples=1, replacement=True)] + + return top_ids + + +def random_sampling(weighted_scores, decoded_tokens, sampling): + probs = paddle.nn.functional.softmax(weighted_scores, axis=0) + top_ids = paddle.multinomial(probs, num_samples=1, replacement=True) + return top_ids +def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor: + batch_size = lengths.shape[0] + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = paddle.arange(0, max_len, dtype='int64') + seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len]) + seq_length_expand = 
lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + +def th_accuracy(pad_outputs: paddle.Tensor, pad_targets: paddle.Tensor, + ignore_label: int) -> paddle.Tensor: + pad_pred = pad_outputs.reshape((pad_targets.shape[0], pad_targets.shape[1], -1)).argmax(axis=2) + mask = pad_targets != ignore_label + numerator = paddle.sum((pad_pred[mask] == pad_targets[mask]).astype('float32')) + denominator = paddle.sum(mask.astype('float32')) + accuracy = numerator / denominator + + return accuracy.detach() +class TransformerLM(paddle.nn.Layer): + def __init__( + self, + text_encoder_input_size: int, + llm_input_size: int, + llm_output_size: int, + text_token_size: int, + speech_token_size: int, + text_encoder: paddle.nn.Layer, + llm: paddle.nn.Layer, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + spk_embed_dim: int = 192, + ): + super().__init__() + self.llm_input_size = llm_input_size + self.speech_token_size = speech_token_size + self.text_embedding = paddle.nn.Embedding( + text_token_size, text_encoder_input_size + ) + self.text_encoder = text_encoder + self.text_encoder_affine_layer = paddle.nn.Linear( + in_features=self.text_encoder.output_size(), out_features=llm_input_size + ) + self.sos_eos = 0 + self.task_id = 1 + self.llm_embedding = paddle.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = paddle.nn.Linear( + in_features=llm_output_size, out_features=speech_token_size + 1 + ) + + self.criterion_ce = LabelSmoothingLoss( + size=speech_token_size + 1, + padding_idx=IGNORE_ID, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + self.speech_embedding = paddle.nn.Embedding(speech_token_size, llm_input_size) + self.spk_embed_affine_layer = paddle.nn.Linear( + in_features=spk_embed_dim, out_features=llm_input_size + ) + self.sampling = sampling + + def encode(self, text: paddle.Tensor, text_lengths: paddle.Tensor): + encoder_out, encoder_mask = self.text_encoder( + text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1 + ) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out = self.text_encoder_affine_layer(encoder_out) + return encoder_out, encoder_out_lens + + def pad_unpad_sequence( + self, + sos_eos_emb, + embedding, + text_token, + text_token_len, + task_id_emb, + speech_token, + speech_token_len, + ): + + text_token = paddle.static.nn.sequence_unpad( + text_token, text_token_len.cpu() + ) + speech_token = paddle.static.nn.sequence_unpad( + speech_token, speech_token_len.cpu() + ) + lm_input = [ + paddle.cat( + [ + sos_eos_emb.squeeze(dim=0), + embedding[i], + text_token[i], + task_id_emb.squeeze(dim=0), + speech_token[i], + ], + dim=0, + ) + for i in range(len(text_token)) + ] + lm_input_len = paddle.tensor([i.size(0) for i in lm_input], dtype=paddle.int32) + lm_input = paddle.static.nn.sequence_unpad( + lm_input, batch_first=True, padding_value=IGNORE_ID + ) + return lm_input, lm_input_len + + def forward( + self, batch: dict, device: torch.device + ) -> Dict[str, Optional[paddle.Tensor]]: + """ + Args: + text: (B, L, D) + text_lengths: (B,) + audio: (B, T, N) or (B, T) + audio_lengths: (B,) + """ + text_token = batch["text_token"].to(device) + text_token_len = batch["text_token_len"].to(device) + speech_token = batch["speech_token"].to(device) + speech_token_len = batch["speech_token_len"].to(device) + embedding = batch["embedding"].to(device) + lm_target = [ + paddle.tensor( + [IGNORE_ID] * (2 + text_token_len[i]) + + speech_token[i, : 
speech_token_len[i]].tolist() + + [self.speech_token_size] + ) + for i in range(text_token.size(0)) + ] + lm_target = torch.nn.utils.rnn.pad_sequence( + lm_target, batch_first=True, padding_value=IGNORE_ID + ).to(device) + text_token = self.text_embedding(text_token) + text_token, text_token_len = self.encode(text_token, text_token_len) + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(1) + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + speech_token = self.speech_embedding(speech_token) + lm_input, lm_input_len = self.pad_unpad_sequence( + sos_eos_emb, + embedding, + text_token, + text_token_len, + task_id_emb, + speech_token, + speech_token_len, + ) + lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device)) + logits = self.llm_decoder(lm_output) + loss = self.criterion_ce(logits, lm_target) + acc = th_accuracy( + logits.view(-1, self.speech_token_size + 1), + lm_target, + ignore_label=IGNORE_ID, + ) + return {"loss": loss, "acc": acc} + + def sampling_ids( + self, + weighted_scores: paddle.Tensor, + decoded_tokens: List, + sampling: int, + ignore_eos: bool = True, + ): + num_trials, max_trials = 0, 100 + while True: + top_ids = self.sampling(weighted_scores, decoded_tokens, sampling) + if (not ignore_eos) or (top_ids < self.speech_token_size): + break + num_trials += 1 + if num_trials > max_trials: + raise RuntimeError( + "sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!".format( + max_trials + ) + ) + return top_ids + + @paddle.no_grad() + def inference( + self, + text: paddle.Tensor, + text_len: paddle.Tensor, + prompt_text: paddle.Tensor, + prompt_text_len: paddle.Tensor, + prompt_speech_token: paddle.Tensor, + prompt_speech_token_len: paddle.Tensor, + embedding: paddle.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + uuid: str = "", + ) -> Generator[paddle.Tensor, None, None]: + device = text.place + text = paddle.cat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.text_embedding(text) + text, text_len = self.encode(text, text_len) + if embedding.shape[0] != 0: + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + embedding = embedding.unsqueeze(dim=1) + else: + embedding = ( + paddle.zeros(1, 0, self.llm_input_size, dtype=text.dtype) + .to(device) + .to(text.dtype) + ) + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = paddle.zeros( + 1, 0, self.llm_input_size, dtype=text.dtype + ).to(device) + lm_input = paddle.cat( + [sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1 + ) + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + out_tokens = [] + offset = 0 + att_cache, cnn_cache = paddle.zeros( + (0, 0, 0, 0), device=lm_input.place + ), paddle.zeros((0, 0, 0, 0), device=lm_input.place) + for i in range(max_len): + y_pred, att_cache, cnn_cache = self.llm.forward_chunk( + lm_input, + offset=offset, + required_cache_size=-1, + 
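+                # Incremental decoding: the attention/conv caches returned by the
+                # previous step are passed back in, so after the first step only
+                # the newly generated token embedding is run through the LM.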
att_cache=att_cache, + cnn_cache=cnn_cache, + att_mask=paddle.tril( + paddle.ones( + (1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.place + ) + ).to(paddle.bool), + ) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if i == 0: + logp[:, self.speech_token_size] = -float("inf") + top_ids = self.sampling_ids( + logp.squeeze(dim=0), + out_tokens, + sampling, + ignore_eos=True if i < min_len else False, + ).item() + if top_ids == self.speech_token_size: + break + yield top_ids + out_tokens.append(top_ids) + offset += lm_input.size(1) + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + +class Qwen2Encoder(paddle.nn.Layer): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, xs: paddle.Tensor, xs_lens: paddle.Tensor): + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T) + outs = self.model( + inputs_embeds=xs, + attention_mask=masks, + output_hidden_states=True, + return_dict=True, + ) + return outs.hidden_states[-1], masks.unsqueeze(1) + + def forward_one_step(self, xs, masks, cache=None,idx = 0): + + input_masks = masks[:, -1, :] + outs = self.model( + inputs_embeds=xs, + attention_mask=input_masks, + output_hidden_states=True, + return_dict=True, + output_attentions=False, + use_cache=True, + past_key_values=cache, + index =idx + ) + xs = outs.hidden_states[-1] + new_cache = outs.past_key_values + xs = paddle.cast(xs, dtype = 'float32') + return xs, new_cache + + +class Qwen2LM(TransformerLM): + def __init__( + self, + llm_input_size: int, + llm_output_size: int, + speech_token_size: int, + llm: paddle.nn.Layer, + sampling: Callable, + length_normalized_loss: bool = True, + lsm_weight: float = 0.0, + mix_ratio: List[int] = [5, 15], + ): + paddle.nn.Layer.__init__(self) + self.llm_input_size = llm_input_size + self.llm_output_size = llm_output_size + self.speech_token_size = speech_token_size + self.sos_eos = 0 + self.task_id = 1 + self.fill_token = 2 + self.llm_embedding = paddle.nn.Embedding(2, llm_input_size) + self.llm = llm + self.llm_decoder = paddle.nn.Linear( + in_features=llm_output_size, out_features=speech_token_size + 3 + ) + self.speech_embedding = paddle.nn.Embedding( + speech_token_size + 3, llm_input_size + ) + self.sampling = sampling + self.mix_ratio = mix_ratio + self.stop_token_ids = [(speech_token_size + i) for i in range(3)] + self.vllm_output_queue = {} + + def prepare_lm_input_target( + self, + text_token, + text_token_emb, + text_token_len, + speech_token, + speech_token_emb, + speech_token_len, + ): + lm_target, lm_input = [], [] + text_token = torch.nn.utils.rnn.unpad_sequence( + text_token, text_token_len.cpu(), batch_first=True + ) + speech_token = torch.nn.utils.rnn.unpad_sequence( + speech_token, speech_token_len.cpu(), batch_first=True + ) + text_token_emb = torch.nn.utils.rnn.unpad_sequence( + text_token_emb, text_token_len.cpu(), batch_first=True + ) + speech_token_emb = torch.nn.utils.rnn.unpad_sequence( + speech_token_emb, speech_token_len.cpu(), batch_first=True + ) + for i in range(len(text_token)): + if ( + random.random() < 0.5 + and speech_token_len[i] / text_token_len[i] + > self.mix_ratio[1] / self.mix_ratio[0] + ): + this_lm_target, this_lm_input = [], [] + this_lm_target.append(IGNORE_ID) + this_lm_input.append( + self.llm_embedding.weight[self.sos_eos].reshape(1, -1) + ) + for j in range( + ((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item() + ): + this_text_token = text_token[i][ + j * self.mix_ratio[0] : (j + 1) * self.mix_ratio[0] + ].tolist() + 
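+                    # interleave in fixed-ratio chunks: each group of mix_ratio[0]
+                    # text tokens is paired with the next mix_ratio[1] speech tokens
+                    # in both lm_input and lm_target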
this_speech_token = speech_token[i][ + j * self.mix_ratio[1] : (j + 1) * self.mix_ratio[1] + ].tolist() + if len(this_text_token) == self.mix_ratio[0]: + assert len(this_speech_token) == self.mix_ratio[1] + this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1) + this_lm_target += this_speech_token + this_lm_target.append(self.speech_token_size + 2) + this_lm_input.append( + text_token_emb[i][ + j * self.mix_ratio[0] : (j + 1) * self.mix_ratio[0] + ] + ) + this_lm_input.append( + speech_token_emb[i][ + j * self.mix_ratio[1] : (j + 1) * self.mix_ratio[1] + ] + ) + else: + this_lm_target += [-1] * len(this_text_token) + this_lm_target += speech_token[i][ + j * self.mix_ratio[1] : + ].tolist() + this_lm_target.append(self.speech_token_size) + this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0] :]) + this_lm_input.append( + self.llm_embedding.weight[self.task_id].reshape(1, -1) + ) + this_lm_input.append( + speech_token_emb[i][j * self.mix_ratio[1] :] + ) + this_lm_target, this_lm_input = paddle.tensor( + this_lm_target + ), paddle.cat(this_lm_input, dim=0) + else: + this_lm_target = paddle.tensor( + [IGNORE_ID] * (1 + text_token_len[i]) + + speech_token[i].tolist() + + [self.speech_token_size] + ) + this_lm_input = paddle.cat( + [ + self.llm_embedding.weight[self.sos_eos].reshape(1, -1), + text_token_emb[i], + self.llm_embedding.weight[self.task_id].reshape(1, -1), + speech_token_emb[i], + ], + dim=0, + ) + lm_target.append(this_lm_target) + lm_input.append(this_lm_input) + lm_input_len = paddle.tensor([i.size(0) for i in lm_input], dtype=paddle.int32) + lm_input = torch.nn.utils.rnn.pad_sequence( + lm_input, batch_first=True, padding_value=IGNORE_ID + ) + lm_target = torch.nn.utils.rnn.pad_sequence( + lm_target, batch_first=True, padding_value=IGNORE_ID + ) + return lm_target, lm_input, lm_input_len + + @paddle.no_grad() + def inference( + self, + text: paddle.Tensor, + text_len: paddle.Tensor, + prompt_text: paddle.Tensor, + prompt_text_len: paddle.Tensor, + prompt_speech_token: paddle.Tensor, + prompt_speech_token_len: paddle.Tensor, + embedding: paddle.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + uuid: str = "", + ) -> Generator[paddle.Tensor, None, None]: + device = text.place + text = paddle.cat([prompt_text, text], dim=1) + text_len += prompt_text_len + text = self.llm.model.qwen2.embed_tokens(text) + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape([1, 1, -1]) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape([1, 1, -1]) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = paddle.zeros( + 1, 0, self.llm_input_size, dtype=text.dtype + ).to(device) + text = paddle.cast(text,dtype = 'float32') + lm_input = paddle.cat( + [sos_eos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1 + ) + min_len = int((text_len - prompt_text_len) * min_token_text_ratio) + max_len = int((text_len - prompt_text_len) * max_token_text_ratio) + for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid): + yield token + + @paddle.no_grad() + def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid): + if hasattr(self, "vllm"): + from vllm import RequestOutput, SamplingParams + + sampling_params = SamplingParams( + top_k=sampling, + stop_token_ids=self.stop_token_ids, + min_tokens=min_len, + max_tokens=max_len, + ) + with self.lock: + self.vllm.add_request( + uuid, + { + "prompt_embeds": 
lm_input.squeeze(0) + .to(paddle.bfloat16) + .to(lm_input.place) + }, + sampling_params, + ) + self.vllm_output_queue[uuid] = queue.Queue() + out_tokens = [] + while True: + with self.lock: + if self.vllm_output_queue[uuid].empty() is True: + request_outputs: List[RequestOutput] = self.vllm.step() + for request_output in request_outputs: + top_ids = list(request_output.outputs[0].token_ids)[-1] + self.vllm_output_queue[request_output.request_id].put( + top_ids + ) + if self.vllm_output_queue[uuid].empty() is False: + top_ids = self.vllm_output_queue[uuid].get() + if top_ids in self.stop_token_ids: + break + yield top_ids + out_tokens.append(top_ids) + if len(out_tokens) == max_len: + break + time.sleep(0.001) + with self.lock: + self.vllm_output_queue.pop(uuid) + else: + out_tokens = [] + cache = None + for i in range(max_len): + + y_pred, cache = self.llm.forward_one_step( + lm_input, + masks=paddle.tril( + paddle.ones( + (1, lm_input.shape[1], lm_input.shape[1]), + ) + ).to(paddle.bool), + cache=cache, + idx = i + ) + + logp = F.log_softmax(self.llm_decoder(y_pred[:, -1]), axis = -1) + top_ids = self.sampling_ids( + logp.squeeze(axis=0), + out_tokens, + sampling, + ignore_eos=True if i < min_len else False, + ) + if top_ids in self.stop_token_ids: + break + yield top_ids + out_tokens.append(top_ids) + lm_input = self.speech_embedding.weight[top_ids].reshape([1, 1, -1]) + @paddle.no_grad() + def inference_bistream( + self, + text: Generator, + prompt_text: paddle.Tensor, + prompt_text_len: paddle.Tensor, + prompt_speech_token: paddle.Tensor, + prompt_speech_token_len: paddle.Tensor, + embedding: paddle.Tensor, + sampling: int = 25, + max_token_text_ratio: float = 20, + min_token_text_ratio: float = 2, + ) -> Generator[paddle.Tensor, None, None]: + device = prompt_text.place + sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1) + task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1) + if prompt_speech_token_len != 0: + prompt_speech_token_emb = self.speech_embedding(prompt_speech_token) + else: + prompt_speech_token_emb = paddle.zeros( + 1, 0, self.llm_input_size, dtype=prompt_text.dtype + ).to(device) + lm_input = paddle.cat([sos_eos_emb], dim=1) + out_tokens = [] + cache = None + text_cache = self.llm.model.model.embed_tokens(prompt_text) + next_fill_index = -1 + for this_text in text: + text_cache = paddle.cat( + [text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1 + ) + while prompt_speech_token_emb.size(1) != 0: + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text, lm_input_speech = ( + text_cache[:, : self.mix_ratio[0]], + prompt_speech_token_emb[:, : self.mix_ratio[1]], + ) + logging.info( + "append {} text token {} speech token".format( + lm_input_text.size(1), lm_input_speech.size(1) + ) + ) + lm_input = paddle.cat( + [lm_input, lm_input_text, lm_input_speech], dim=1 + ) + text_cache, prompt_speech_token_emb = ( + text_cache[:, self.mix_ratio[0] :], + prompt_speech_token_emb[:, self.mix_ratio[1] :], + ) + else: + logging.info("not enough text token to decode, wait for more") + break + if prompt_speech_token_emb.size(1) == 0: + if ( + len(out_tokens) != 0 + and out_tokens[-1] == self.speech_token_size + 2 + or len(out_tokens) == 0 + and lm_input.size(1) == 1 + ): + logging.info("get fill token, need to append more text token") + if text_cache.size(1) >= self.mix_ratio[0]: + lm_input_text = text_cache[:, : self.mix_ratio[0]] + logging.info( + "append {} text token".format(lm_input_text.size(1)) + ) + if ( + len(out_tokens) 
!= 0 + and out_tokens[-1] == self.speech_token_size + 2 + ): + lm_input = lm_input_text + else: + lm_input = paddle.cat([lm_input, lm_input_text], dim=1) + text_cache = text_cache[:, self.mix_ratio[0] :] + else: + logging.info("not enough text token to decode, wait for more") + continue + while True: + seq_len = ( + lm_input.shape[1] + if cache is None + else lm_input.shape[1] + cache[0][0].size(2) + ) + y_pred, cache = self.llm.forward_one_step( + lm_input, + masks=paddle.tril( + paddle.ones((1, seq_len, seq_len), device=lm_input.place) + ).to(paddle.bool), + cache=cache, + ) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + if next_fill_index != -1 and len(out_tokens) == next_fill_index: + top_ids = self.speech_token_size + 2 + next_fill_index += self.mix_ratio[1] + 1 + else: + top_ids = self.sampling_ids( + logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True + ).item() + if top_ids == self.speech_token_size + 2: + next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1 + logging.info( + "fill_token index {} next fill_token index {}".format( + len(out_tokens), next_fill_index + ) + ) + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size + 2: + break + else: + raise ValueError("should not get token {}".format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + lm_input = paddle.cat([lm_input, text_cache, task_id_emb], dim=1) + logging.info("no more text token, decode until met eos") + while True: + seq_len = ( + lm_input.shape[1] + if cache is None + else lm_input.shape[1] + cache[0][0].size(2) + ) + y_pred, cache = self.llm.forward_one_step( + lm_input, + masks=paddle.tril( + paddle.ones((1, seq_len, seq_len), device=lm_input.place) + ).to(paddle.bool), + cache=cache, + ) + logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1) + top_ids = self.sampling_ids( + logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False + ).item() + out_tokens.append(top_ids) + if top_ids >= self.speech_token_size: + if top_ids == self.speech_token_size: + break + else: + raise ValueError("should not get token {}".format(top_ids)) + yield top_ids + lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1) + + diff --git a/paddlespeech/t2s/models/CosyVoice/mask.py b/paddlespeech/t2s/models/CosyVoice/mask.py new file mode 100644 index 000000000..dd10276db --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/mask.py @@ -0,0 +1,118 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +def add_optional_chunk_mask( + xs: paddle.Tensor, + masks: paddle.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True, +): + """Apply optional mask for encoder. 
+ + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + chunk_size = paddle.randint(low=1, high=max_len, shape=(1,)).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = paddle.randint( + low=0, high=max_left_chunks, shape=(1,) + ).item() + chunk_masks = subsequent_chunk_mask( + xs.size(1), chunk_size, num_left_chunks, xs.place + ) + chunk_masks = chunk_masks.unsqueeze(0) + chunk_masks = masks & chunk_masks + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask( + xs.size(1), static_chunk_size, num_left_chunks, xs.place + ) + chunk_masks = chunk_masks.unsqueeze(0) + chunk_masks = masks & chunk_masks + else: + chunk_masks = masks + assert chunk_masks.dtype == paddle.bool + if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: + print( + "get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!" + ) + chunk_masks[chunk_masks.sum(dim=-1) == 0] = True + return chunk_masks + + +def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.shape[0] + max_len = max_len if max_len > 0 else lengths.max().item() + seq_range = paddle.arange(0, max_len, dtype=paddle.int32) + seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len]) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask \ No newline at end of file diff --git a/paddlespeech/t2s/models/CosyVoice/model.py b/paddlespeech/t2s/models/CosyVoice/model.py new file mode 100644 index 000000000..1fd07de8a --- /dev/null +++ b/paddlespeech/t2s/models/CosyVoice/model.py @@ -0,0 +1,607 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading +import time +import uuid +from contextlib import nullcontext +from typing import Generator + +import numpy as np +import paddle + + +class CosyVoiceModel: + def __init__( + self, + llm: paddle.nn.Layer, + flow: paddle.nn.Layer, + hift: paddle.nn.Layer, + fp16: bool = False, + ): + self.device = device2str( + "cuda" if paddle.device.cuda.device_count() >= 1 else "cpu" + ) + self.llm = llm + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.llm.half() + self.flow.half() + self.token_min_hop_len = 2 * self.flow.input_frame_rate + self.token_max_hop_len = 4 * self.flow.input_frame_rate + self.token_overlap_len = 20 + self.mel_overlap_len = int( + self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256 + ) + self.mel_window = np.hamming(2 * self.mel_overlap_len) + self.mel_cache_len = 20 + self.source_cache_len = int(self.mel_cache_len * 256) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.stream_scale_factor = 1 + assert ( + self.stream_scale_factor >= 1 + ), "stream_scale_factor should be greater than 1, change it according to your actual rtf" + self.llm_context = ( + paddle.device.stream_guard( + paddle.device.Stream(device=device2str(self.device)) + ) + if paddle.device.cuda.device_count() >= 1 + else nullcontext() + ) + self.lock = threading.Lock() + self.tts_speech_token_dict = {} + self.llm_end_dict = {} + self.mel_overlap_dict = {} + self.flow_cache_dict = {} + self.hift_cache_dict = {} + + def load(self, llm_model, flow_model, hift_model): + self.llm.set_state_dict(state_dict=paddle.load(path=str(llm_model))) + self.llm.to(self.device).eval() + self.flow.set_state_dict(state_dict=paddle.load(path=str(flow_model))) + self.flow.to(self.device).eval() + hift_state_dict = { + k.replace("generator.", ""): v + for k, v in paddle.load(path=str(hift_model)).items() + } + self.hift.set_state_dict(state_dict=hift_state_dict) + self.hift.to(self.device).eval() + + def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model): + llm_text_encoder = torch.jit.load( + llm_text_encoder_model, map_location=self.device + ) + self.llm.text_encoder = llm_text_encoder + llm_llm = torch.jit.load(llm_llm_model, map_location=self.device) + self.llm.llm = llm_llm + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load_trt( + self, + flow_decoder_estimator_model, + flow_decoder_onnx_model, + trt_concurrent, + fp16, + ): + assert paddle.device.cuda.device_count() >= 1, "tensorrt only supports gpu!" 
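+        # build the TensorRT engine from the ONNX flow decoder when the engine file
+        # is missing or empty, then deserialize it and install it as flow.decoder.estimator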
+ if ( + not os.path.exists(flow_decoder_estimator_model) + or os.path.getsize(flow_decoder_estimator_model) == 0 + ): + convert_onnx_to_trt( + flow_decoder_estimator_model, + self.get_trt_kwargs(), + flow_decoder_onnx_model, + fp16, + ) + del self.flow.decoder.estimator + import tensorrt as trt + + with open(flow_decoder_estimator_model, "rb") as f: + estimator_engine = trt.Runtime( + trt.Logger(trt.Logger.INFO) + ).deserialize_cuda_engine(f.read()) + assert estimator_engine is not None, "failed to load trt {}".format( + flow_decoder_estimator_model + ) + self.flow.decoder.estimator = TrtContextWrapper( + estimator_engine, trt_concurrent=trt_concurrent, device=self.device + ) + + def get_trt_kwargs(self): + min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)] + opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)] + max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)] + input_names = ["x", "mask", "mu", "cond"] + return { + "min_shape": min_shape, + "opt_shape": opt_shape, + "max_shape": max_shape, + "input_names": input_names, + } + + def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid): + with self.llm_context, paddle.amp.auto_cast( + enable=self.fp16 is True and hasattr(self.llm, "vllm") is False + ): + if isinstance(text, Generator): + assert isinstance(self, CosyVoice2Model) and not hasattr( + self.llm, "vllm" + ), "streaming input text is only implemented for CosyVoice2 and do not support vllm!" + for i in self.llm.inference_bistream( + text=text, + prompt_text=prompt_text.to(self.device), + prompt_text_len=paddle.tensor( + [prompt_text.shape[1]], dtype=paddle.int32 + ).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=paddle.tensor( + [llm_prompt_speech_token.shape[1]], dtype=paddle.int32 + ).to(self.device), + embedding=llm_embedding.to(self.device), + ): + self.tts_speech_token_dict[uuid].append(i) + else: + for i in self.llm.inference( + text=text.to(self.device), + text_len=paddle.tensor([text.shape[1]], dtype=paddle.int32).to( + self.device + ), + prompt_text=prompt_text.to(self.device), + prompt_text_len=paddle.tensor( + [prompt_text.shape[1]], dtype=paddle.int32 + ).to(self.device), + prompt_speech_token=llm_prompt_speech_token.to(self.device), + prompt_speech_token_len=paddle.tensor( + [llm_prompt_speech_token.shape[1]], dtype=paddle.int32 + ).to(self.device), + embedding=llm_embedding.to(self.device), + uuid=uuid, + ): + self.tts_speech_token_dict[uuid].append(i) + self.llm_end_dict[uuid] = True + + def vc_job(self, source_speech_token, uuid): + self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist() + self.llm_end_dict[uuid] = True + + def token2wav( + self, + token, + prompt_token, + prompt_feat, + embedding, + uuid, + finalize=False, + speed=1.0, + ): + with paddle.amp.auto_cast(enable=self.fp16): + tts_mel, self.flow_cache_dict[uuid] = self.flow.inference( + token=token.to(self.device), + token_len=paddle.tensor([token.shape[1]], dtype=paddle.int32).to( + self.device + ), + prompt_token=prompt_token.to(self.device), + prompt_token_len=paddle.tensor( + [prompt_token.shape[1]], dtype=paddle.int32 + ).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=paddle.tensor( + [prompt_feat.shape[1]], dtype=paddle.int32 + ).to(self.device), + embedding=embedding.to(self.device), + flow_cache=self.flow_cache_dict[uuid], + ) + if self.mel_overlap_dict[uuid].shape[2] != 0: + tts_mel = fade_in_out(tts_mel, 
self.mel_overlap_dict[uuid], self.mel_window) + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = ( + self.hift_cache_dict[uuid]["mel"], + self.hift_cache_dict[uuid]["source"], + ) + tts_mel = paddle.cat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = paddle.zeros([1, 1, 0]) + if finalize is False: + self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len :] + tts_mel = tts_mel[:, :, : -self.mel_overlap_len] + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window + ) + self.hift_cache_dict[uuid] = { + "mel": tts_mel[:, :, -self.mel_cache_len :], + "source": tts_source[:, :, -self.source_cache_len :], + "speech": tts_speech[:, -self.source_cache_len :], + } + tts_speech = tts_speech[:, : -self.source_cache_len] + else: + if speed != 1.0: + assert ( + self.hift_cache_dict[uuid] is None + ), "speed change only support non-stream inference mode" + tts_mel = paddle.nn.functional.interpolate( + x=tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear" + ) + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window + ) + return tts_speech + + def tts( + self, + text=paddle.zeros([1, 0], dtype=paddle.int32), + flow_embedding=paddle.zeros([0, 192]), + llm_embedding=paddle.zeros([0, 192]), + prompt_text=paddle.zeros([1, 0], dtype=paddle.int32), + llm_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + flow_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + prompt_speech_feat=paddle.zeros([1, 0, 80]), + source_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + stream=False, + speed=1.0, + **kwargs + ): + this_uuid = str(uuid.uuid1()) + with self.lock: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = ( + [], + False, + ) + self.hift_cache_dict[this_uuid] = None + self.mel_overlap_dict[this_uuid] = paddle.zeros([1, 80, 0]) + self.flow_cache_dict[this_uuid] = paddle.zeros([1, 80, 0, 2]) + if source_speech_token.shape[1] == 0: + p = threading.Thread( + target=self.llm_job, + args=( + text, + prompt_text, + llm_prompt_speech_token, + llm_embedding, + this_uuid, + ), + ) + else: + p = threading.Thread( + target=self.vc_job, args=(source_speech_token, this_uuid) + ) + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" + p.start() + if stream is True: + token_hop_len = self.token_min_hop_len + while True: + time.sleep(0.1) + if ( + len(self.tts_speech_token_dict[this_uuid]) + >= token_hop_len + self.token_overlap_len + ): + this_tts_speech_token = paddle.tensor( + self.tts_speech_token_dict[this_uuid][ + : token_hop_len + self.token_overlap_len + ] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=False, + ) + yield {"tts_speech": this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict[ + this_uuid + ] = self.tts_speech_token_dict[this_uuid][token_hop_len:] + token_hop_len = min( + self.token_max_hop_len, + int(token_hop_len * self.stream_scale_factor), + ) + if ( + 
self.llm_end_dict[this_uuid] is True + and len(self.tts_speech_token_dict[this_uuid]) + < token_hop_len + self.token_overlap_len + ): + break + p.join() + this_tts_speech_token = paddle.tensor( + self.tts_speech_token_dict[this_uuid] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True, + ) + yield {"tts_speech": this_tts_speech.cpu()} + else: + p.join() + this_tts_speech_token = paddle.tensor( + self.tts_speech_token_dict[this_uuid] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + uuid=this_uuid, + finalize=True, + speed=speed, + ) + yield {"tts_speech": this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + self.mel_overlap_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) + self.flow_cache_dict.pop(this_uuid) + if paddle.device.cuda.device_count() >= 1: + paddle.device.cuda.empty_cache() + paddle.device.current_stream().synchronize() + + +class CosyVoice2Model(CosyVoiceModel): + def __init__( + self, + llm: paddle.nn.Layer, + flow: paddle.nn.Layer, + hift: paddle.nn.Layer, + fp16: bool = False, + ): + self.device = device2str( + "cuda" if paddle.device.cuda.device_count() >= 1 else "cpu" + ) + self.llm = llm + self.flow = flow + self.hift = hift + self.fp16 = fp16 + if self.fp16 is True: + self.llm.half() + self.flow.half() + self.token_hop_len = 25 + self.mel_cache_len = 8 + self.source_cache_len = int(self.mel_cache_len * 480) + self.speech_window = np.hamming(2 * self.source_cache_len) + self.llm_context = ( + paddle.device.stream_guard( + paddle.device.Stream(device=device2str(self.device)) + ) + if paddle.device.cuda.device_count() >= 1 + else nullcontext() + ) + self.lock = threading.Lock() + self.tts_speech_token_dict = {} + self.llm_end_dict = {} + self.hift_cache_dict = {} + + def load_jit(self, flow_encoder_model): + flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device) + self.flow.encoder = flow_encoder + + def load_vllm(self, model_dir): + export_cosyvoice2_vllm(self.llm, model_dir, self.device) + from vllm import EngineArgs, LLMEngine + + engine_args = EngineArgs( + model=model_dir, + skip_tokenizer_init=True, + enable_prompt_embeds=True, + gpu_memory_utilization=0.2, + ) + self.llm.vllm = LLMEngine.from_engine_args(engine_args) + self.llm.lock = threading.Lock() + del self.llm.llm.model.model.layers + + def token2wav( + self, + token, + prompt_token, + prompt_feat, + embedding, + token_offset, + uuid, + stream=False, + finalize=False, + speed=1.0, + ): + with paddle.amp.auto_cast(enable=self.fp16): + tts_mel, _ = self.flow.inference( + token=token.to(self.device), + token_len=paddle.tensor([token.shape[1]], dtype=paddle.int32).to( + self.device + ), + prompt_token=prompt_token.to(self.device), + prompt_token_len=paddle.tensor( + [prompt_token.shape[1]], dtype=paddle.int32 + ).to(self.device), + prompt_feat=prompt_feat.to(self.device), + prompt_feat_len=paddle.tensor( + [prompt_feat.shape[1]], dtype=paddle.int32 + ).to(self.device), + embedding=embedding.to(self.device), + streaming=stream, + finalize=finalize, + ) + tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio :] + if self.hift_cache_dict[uuid] is not None: + hift_cache_mel, hift_cache_source = ( + 
self.hift_cache_dict[uuid]["mel"], + self.hift_cache_dict[uuid]["source"], + ) + tts_mel = paddle.cat([hift_cache_mel, tts_mel], dim=2) + else: + hift_cache_source = paddle.zeros([1, 1, 0]) + if finalize is False: + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window + ) + self.hift_cache_dict[uuid] = { + "mel": tts_mel[:, :, -self.mel_cache_len :], + "source": tts_source[:, :, -self.source_cache_len :], + "speech": tts_speech[:, -self.source_cache_len :], + } + tts_speech = tts_speech[:, : -self.source_cache_len] + else: + if speed != 1.0: + assert ( + self.hift_cache_dict[uuid] is None + ), "speed change only support non-stream inference mode" + tts_mel = paddle.nn.functional.interpolate( + x=tts_mel, size=int(tts_mel.shape[2] / speed), mode="linear" + ) + tts_speech, tts_source = self.hift.inference( + speech_feat=tts_mel, cache_source=hift_cache_source + ) + if self.hift_cache_dict[uuid] is not None: + tts_speech = fade_in_out( + tts_speech, self.hift_cache_dict[uuid]["speech"], self.speech_window + ) + return tts_speech + + def tts( + self, + text=paddle.zeros([1, 0], dtype=paddle.int32), + flow_embedding=paddle.zeros([0, 192]), + llm_embedding=paddle.zeros([0, 192]), + prompt_text=paddle.zeros([1, 0], dtype=paddle.int32), + llm_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + flow_prompt_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + prompt_speech_feat=paddle.zeros([1, 0, 80]), + source_speech_token=paddle.zeros([1, 0], dtype=paddle.int32), + stream=False, + speed=1.0, + **kwargs + ): + this_uuid = str(uuid.uuid1()) + with self.lock: + self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = ( + [], + False, + ) + self.hift_cache_dict[this_uuid] = None + if source_speech_token.shape[1] == 0: + p = threading.Thread( + target=self.llm_job, + args=( + text, + prompt_text, + llm_prompt_speech_token, + llm_embedding, + this_uuid, + ), + ) + else: + p = threading.Thread( + target=self.vc_job, args=(source_speech_token, this_uuid) + ) + """Not Support auto convert *.start, please judge whether it is Pytorch API and convert by yourself""" + p.start() + if stream is True: + token_offset = 0 + prompt_token_pad = int( + np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) + * self.token_hop_len + - flow_prompt_speech_token.shape[1] + ) + while True: + time.sleep(0.1) + this_token_hop_len = ( + self.token_hop_len + prompt_token_pad + if token_offset == 0 + else self.token_hop_len + ) + if ( + len(self.tts_speech_token_dict[this_uuid]) - token_offset + >= this_token_hop_len + self.flow.pre_lookahead_len + ): + this_tts_speech_token = paddle.tensor( + self.tts_speech_token_dict[this_uuid][ + : token_offset + + this_token_hop_len + + self.flow.pre_lookahead_len + ] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=token_offset, + uuid=this_uuid, + stream=stream, + finalize=False, + ) + token_offset += this_token_hop_len + yield {"tts_speech": this_tts_speech.cpu()} + if ( + self.llm_end_dict[this_uuid] is True + and len(self.tts_speech_token_dict[this_uuid]) - token_offset + < this_token_hop_len + self.flow.pre_lookahead_len + ): + break + p.join() + this_tts_speech_token = paddle.tensor( + 
self.tts_speech_token_dict[this_uuid] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=token_offset, + uuid=this_uuid, + finalize=True, + ) + yield {"tts_speech": this_tts_speech.cpu()} + else: + p.join() + this_tts_speech_token = paddle.tensor( + self.tts_speech_token_dict[this_uuid] + ).unsqueeze(dim=0) + this_tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=0, + uuid=this_uuid, + finalize=True, + speed=speed, + ) + yield {"tts_speech": this_tts_speech.cpu()} + with self.lock: + self.tts_speech_token_dict.pop(this_uuid) + self.llm_end_dict.pop(this_uuid) + self.hift_cache_dict.pop(this_uuid) + if paddle.device.cuda.device_count() >= 1: + paddle.device.cuda.empty_cache() + paddle.device.current_stream().synchronize() \ No newline at end of file diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py index d8df4368a..c01e79370 100644 --- a/paddlespeech/t2s/models/__init__.py +++ b/paddlespeech/t2s/models/__init__.py @@ -22,3 +22,4 @@ from .transformer_tts import * from .vits import * from .waveflow import * from .wavernn import * +from .CosyVoice import * diff --git a/paddlespeech/t2s/models/hifigan/__init__.py b/paddlespeech/t2s/models/hifigan/__init__.py index 7aa5e9d78..51c0924dd 100644 --- a/paddlespeech/t2s/models/hifigan/__init__.py +++ b/paddlespeech/t2s/models/hifigan/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from .hifigan import * from .hifigan_updater import * +from .cosy_hifigan import * +from .f0_predictor import * \ No newline at end of file diff --git a/paddlespeech/t2s/models/hifigan/cosy_hifigan.py b/paddlespeech/t2s/models/hifigan/cosy_hifigan.py new file mode 100644 index 000000000..d722bf2eb --- /dev/null +++ b/paddlespeech/t2s/models/hifigan/cosy_hifigan.py @@ -0,0 +1,607 @@ +import paddle + +"""HIFI-GAN""" +from typing import Dict, List, Optional + +import numpy as np +from scipy.signal import get_window +from paddlespeech.t2s.modules.transformer.activation import Snake +from paddlespeech.t2s.models.CosyVoice.common import get_padding, init_weights + +"""hifigan based generator implementation. 
+ +This code is modified from https://github.com/jik876/hifi-gan + ,https://github.com/kan-bayashi/ParallelWaveGAN and + https://github.com/NVIDIA/BigVGAN + +""" + + +class ResBlock(paddle.nn.Layer): + """Residual block module in HiFiGAN/BigVGAN.""" + + def __init__(self, channels: int=512, kernel_size: int=3, dilations: + List[int]=[1, 3, 5]): + super(ResBlock, self).__init__() + self.convs1 = paddle.nn.LayerList() + self.convs2 = paddle.nn.LayerList() + for dilation in dilations: + self.convs1.append(paddle.nn.Conv1D(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation))) + self.convs2.append(paddle.nn.Conv1D(channels, channels, kernel_size, 1, dilation=1,padding=get_padding(kernel_size, 1))) + self.convs1.apply(init_weights) + self.convs2.apply(init_weights) + self.activations1 = paddle.nn.LayerList(sublayers=[Snake(channels, + alpha_logscale=False) for _ in range(len(self.convs1))]) + self.activations2 = paddle.nn.LayerList(sublayers=[Snake(channels, + alpha_logscale=False) for _ in range(len(self.convs2))]) + + def forward(self, x: paddle.Tensor) ->paddle.Tensor: + for idx in range(len(self.convs1)): + xt = self.activations1[idx](x) + xt = self.convs1[idx](xt) + xt = self.activations2[idx](xt) + xt = self.convs2[idx](xt) + x = xt + x + return x + + def remove_weight_norm(self): + for idx in range(len(self.convs1)): + paddle.nn.utils.remove_weight_norm(layer=self.convs1[idx]) + paddle.nn.utils.remove_weight_norm(layer=self.convs2[idx]) + + +class SineGen(paddle.nn.Layer): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std= + 0.003, voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + uv = (f0 > self.voiced_threshold).astype(paddle.float32) + return uv + + @paddle.no_grad() + def forward(self, f0): + """ + :param f0: [B, 1, sample_len], Hz + :return: [B, 1, sample_len] + """ + F_mat = paddle.zeros([f0.size(0), self.harmonic_num + 1, f0.size(-1)]).to(f0.place) + for i in range(self.harmonic_num + 1): + F_mat[:, i:i + 1, :] = f0 * (i + 1) / self.sampling_rate + theta_mat = 2 * np.pi * (paddle.cumsum(F_mat, axis=-1) % 1) + u_dist = paddle.distribution.Uniform(low=-np.pi, high=np.pi) + phase_vec = u_dist.sample(shape=(f0.size(0), self.harmonic_num + 1, 1) + ).to(F_mat.place) + phase_vec[:, 0, :] = 0 + sine_waves = self.sine_amp * paddle.sin(theta_mat + phase_vec) + uv = self._f02uv(f0) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * paddle.randn(shape=sine_waves.shape, dtype= + sine_waves.dtype) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + +class SourceModuleHnNSF(paddle.nn.Layer): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, 
+ add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + if sinegen_type == '1': + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod) + else: + self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal) + self.l_linear = paddle.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = paddle.nn.Tanh() + self.causal = causal + paddle.seed(1986) + + if causal is True: + self.uv = paddle.rand(shape=[1, 300 * 24000, 1]) + self.register_buffer('uv_buffer', self.uv) + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + with paddle.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + if not self.training and self.causal: + noise = self.uv_buffer[:, :uv.shape[1]] * self.sine_amp / 3 + else: + noise = paddle.randn(shape=uv.shape, dtype=uv.dtype) * self.sine_amp / 3 + return sine_merge, noise, uv +# class SourceModuleHnNSF(paddle.nn.Layer): +# """ SourceModule for hn-nsf +# SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, +# add_noise_std=0.003, voiced_threshod=0) +# sampling_rate: sampling_rate in Hz +# harmonic_num: number of harmonic above F0 (default: 0) +# sine_amp: amplitude of sine source signal (default: 0.1) +# add_noise_std: std of additive Gaussian noise (default: 0.003) +# note that amplitude of noise in unvoiced is decided +# by sine_amp +# voiced_threshold: threhold to set U/V given F0 (default: 0) +# Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) +# F0_sampled (batchsize, length, 1) +# Sine_source (batchsize, length, 1) +# noise_source (batchsize, length 1) +# uv (batchsize, length, 1) +# """ + +# def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, +# sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0): +# super(SourceModuleHnNSF, self).__init__() +# self.sine_amp = sine_amp +# self.noise_std = add_noise_std +# self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, +# add_noise_std, voiced_threshod) +# self.l_linear = paddle.nn.Linear(in_features=harmonic_num + 1, +# out_features=1) +# self.l_tanh = paddle.nn.Tanh() + +# def forward(self, x): +# """ +# Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) +# F0_sampled (batchsize, length, 1) +# Sine_source (batchsize, length, 1) +# noise_source (batchsize, length 1) +# """ +# with paddle.no_grad(): +# sine_wavs, uv, _ = self.l_sin_gen(paddle.transpose(x,perm=[0,2,1])) +# sine_wavs = paddle.transpose(sine_wavs,perm=[0,2,1]) +# uv = paddle.transpose(uv,perm=[0,2,1]) +# sine_merge = 
self.l_tanh(self.l_linear(sine_wavs)) + +# noise = paddle.randn(shape=uv.shape, dtype=uv.dtype +# ) * self.sine_amp / 3 +# return sine_merge, noise, uv + + +# class SineGen2(paddle.nn.Layer): +# """ Definition of sine generator +# SineGen(samp_rate, harmonic_num = 0, +# sine_amp = 0.1, noise_std = 0.003, +# voiced_threshold = 0, +# flag_for_pulse=False) +# samp_rate: sampling rate in Hz +# harmonic_num: number of harmonic overtones (default 0) +# sine_amp: amplitude of sine-wavefrom (default 0.1) +# noise_std: std of Gaussian noise (default 0.003) +# voiced_thoreshold: F0 threshold for U/V classification (default 0) +# flag_for_pulse: this SinGen is used inside PulseGen (default False) +# Note: when flag_for_pulse is True, the first time step of a voiced +# segment is always sin(np.pi) or cos(0) +# """ + +# def __init__(self, samp_rate, upsample_scale, harmonic_num=0, sine_amp= +# 0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False): +# super(SineGen2, self).__init__() +# self.sine_amp = sine_amp +# self.noise_std = noise_std +# self.harmonic_num = harmonic_num +# self.axis = self.harmonic_num + 1 +# self.sampling_rate = samp_rate +# self.voiced_threshold = voiced_threshold +# self.flag_for_pulse = flag_for_pulse +# self.upsample_scale = upsample_scale + +# def _f02uv(self, f0): +# uv = (f0 > self.voiced_threshold).astype(paddle.float32) +# return uv + +# def _f02sine(self, f0_values): +# """ f0_values: (batchsize, length, axis) +# where axis indicates fundamental tone and overtones +# """ +# rad_values = f0_values / self.sampling_rate % 1 +# rand_ini = paddle.rand(shape=[f0_values.shape[0], f0_values.shape[2]]) +# rand_ini[:, 0] = 0 +# rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini +# if not self.flag_for_pulse: +# x = paddle.transpose(rad_values,perm = [0,2,1]) + +# rad_values = paddle.transpose(paddle.nn.functional.interpolate(x=x, scale_factor=1 / self.upsample_scale, mode='linear'),perm = [0,2,1]) +# phase = paddle.cumsum(rad_values, axis=1) * 2 * np.pi +# phase = paddle.transpose(paddle.nn.functional.interpolate(x=paddle.transpose(phase,perm = [0,2,1]) * self.upsample_scale, scale_factor=int(self.upsample_scale),mode='linear'),perm = [0,2,1]) +# sines = paddle.sin(phase) +# else: +# uv = self._f02uv(f0_values) +# uv_1 = paddle.roll(uv, shifts=-1, axis=1) +# uv_1[:, -1, :] = 1 +# u_loc = (uv < 1) * (uv_1 > 0) +# tmp_cumsum = paddle.cumsum(rad_values, axis=1) +# for idx in range(f0_values.shape[0]): +# temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] +# temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] +# tmp_cumsum[idx, :, :] = 0 +# tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum +# i_phase = paddle.cumsum(rad_values - tmp_cumsum, axis=1) +# sines = paddle.cos(i_phase * 2 * np.pi) +# return sines + +# def forward(self, f0): +# """ sine_tensor, uv = forward(f0) +# input F0: tensor(batchsize=1, length, axis=1) +# f0 for unvoiced steps should be 0 +# output sine_tensor: tensor(batchsize=1, length, axis) +# output uv: tensor(batchsize=1, length, 1) +# """ +# paddle.seed(1986) +# fn = paddle.multiply(f0, paddle.to_tensor([[range(1, self.harmonic_num + +# 2)]],dtype='float32',place=f0.place)) + +# sine_waves = self._f02sine(fn) * self.sine_amp +# uv = self._f02uv(f0) +# noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 +# noise = noise_amp * paddle.randn(shape=sine_waves.shape, dtype= +# sine_waves.dtype) +# sine_waves = sine_waves * uv + noise + +# return sine_waves, uv, noise +class SineGen2(paddle.nn.Layer): + """ Definition of sine generator + 
SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + causal=False): + super(SineGen2, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.upsample_scale = upsample_scale + self.causal = causal + paddle.seed(1986) + if causal is True: + self.rand_ini = paddle.rand(shape=[1, 9]) + self.rand_ini[:, 0] = 0 + self.sine_waves = paddle.rand(shape=[1, 300 * 24000, 9]) + self.register_buffer('rand_ini_buffer', self.rand_ini) + self.register_buffer('sine_waves_buffer', self.sine_waves) + + def _f02uv(self, f0): + uv = (f0 > self.voiced_threshold).astype('float32') + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) """ + rad_values = (f0_values / self.sampling_rate) % 1 + if not self.training and self.causal: + rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini_buffer + else: + rand_ini = paddle.rand(shape=[f0_values.shape[0], f0_values.shape[2]]) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + if not self.flag_for_pulse: + scale_factor_down = 1.0 / self.upsample_scale + + rad_values = paddle.nn.functional.interpolate( + rad_values.transpose([0, 2, 1]), + scale_factor=scale_factor_down, + mode="linear" + ).transpose([0, 2, 1]) + phase = paddle.cumsum(rad_values, axis=1) * 2 * np.pi + + interpolate_mode = "nearest" if self.causal else 'linear' + phase = paddle.transpose(paddle.nn.functional.interpolate( + paddle.transpose(phase,perm=[0, 2, 1])*self.upsample_scale, + scale_factor=float(self.upsample_scale), + mode=interpolate_mode + ),perm = [0, 2, 1]) + + sines = paddle.sin(phase) + else: + uv = self._f02uv(f0_values) + + uv_1 = paddle.roll(uv, shifts=-1, axis=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + tmp_cumsum = paddle.cumsum(rad_values, axis=1) + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + i_phase = paddle.cumsum(rad_values - tmp_cumsum, axis=1) + sines = paddle.cos(i_phase * 2 * np.pi) + + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) """ + paddle.seed(1986) + harmonic_coeffs = paddle.to_tensor( + [list(range(1, self.harmonic_num + 2))], + dtype='float32' + ).reshape([1, 1, -1]) + + fn = f0 * harmonic_coeffs + + sine_waves = self._f02sine(fn) * self.sine_amp + uv = self._f02uv(f0) + + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + + if not self.training and self.causal: + noise = noise_amp * self.sine_waves_buffer[:, :sine_waves.shape[1]] + else: + noise = noise_amp * paddle.randn( + shape=sine_waves.shape, + 
dtype=sine_waves.dtype + ) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + +class SourceModuleHnNSF2(paddle.nn.Layer): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, + sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF2, self).__init__() + self.sine_amp = sine_amp + self.noise_std = add_noise_std + self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, + harmonic_num, sine_amp, add_noise_std, voiced_threshod) + self.l_linear = paddle.nn.Linear(in_features=harmonic_num + 1, + out_features=1) + self.l_tanh = paddle.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + paddle.seed(1986) + with paddle.no_grad(): + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + noise = paddle.randn(shape=uv.shape, dtype=uv.dtype + ) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class HiFTGenerator(paddle.nn.Layer): + """ + HiFTNet Generator: Neural Source Filter + ISTFTNet + https://arxiv.org/abs/2309.09493 + """ + + def __init__(self, in_channels: int=80, base_channels: int=512, + nb_harmonics: int=8, sampling_rate: int=22050, nsf_alpha: float=0.1, + nsf_sigma: float=0.003, nsf_voiced_threshold: float=10, + upsample_rates: List[int]=[8, 8], upsample_kernel_sizes: List[int]= + [16, 16], istft_params: Dict[str, int]={'n_fft': 16, 'hop_len': 4}, + resblock_kernel_sizes: List[int]=[3, 7, 11], + resblock_dilation_sizes: List[List[int]]=[[1, 3, 5], [1, 3, 5], [1, + 3, 5]], source_resblock_kernel_sizes: List[int]=[7, 11], + source_resblock_dilation_sizes: List[List[int]]=[[1, 3, 5], [1, 3, + 5]], lrelu_slope: float=0.1, audio_limit: float=0.99, f0_predictor: + paddle.nn.Layer=None): + super(HiFTGenerator, self).__init__() + self.out_channels = 1 + self.nb_harmonics = nb_harmonics + self.sampling_rate = sampling_rate + self.istft_params = istft_params + self.lrelu_slope = lrelu_slope + self.audio_limit = audio_limit + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=sampling_rate, + upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], + harmonic_num=nb_harmonics, + sine_amp=nsf_alpha, + add_noise_std=nsf_sigma, + voiced_threshod=nsf_voiced_threshold, + sinegen_type='1' if self.sampling_rate == 22050 else '2', + causal=False) + self.f0_upsamp = paddle.nn.Upsample(scale_factor=(1,int(np.prod( upsample_rates) * istft_params['hop_len']))) + self.conv_pre = paddle.nn.Conv1D(in_channels = in_channels, out_channels = base_channels, kernel_size=7, stride=1, padding=3) + self.ups = paddle.nn.LayerList() + for i, (u, k) in enumerate(zip(upsample_rates, 
upsample_kernel_sizes)): + self.ups.append(paddle.nn.Conv1DTranspose(in_channels=base_channels // 2 ** i,out_channels=base_channels // 2 ** (i + 1), kernel_size=k,stride=u, padding=(k - u) // 2)) + self.source_downs = paddle.nn.LayerList() + self.source_resblocks = paddle.nn.LayerList() + downsample_rates = [1] + upsample_rates[::-1][:-1] + downsample_cum_rates = np.cumprod(downsample_rates) + for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], + source_resblock_kernel_sizes, source_resblock_dilation_sizes)): + if u == 1: + self.source_downs.append(paddle.nn.Conv1D(istft_params[ + 'n_fft'] + 2, base_channels // 2 ** (i + 1), 1, 1)) + else: + self.source_downs.append(paddle.nn.Conv1D( + in_channels=istft_params['n_fft'] + 2, + out_channels=base_channels // (2 ** (i + 1)), + kernel_size=(u * 2,), + stride=(u,), + padding=int(u // 2), + )) + self.source_resblocks.append(ResBlock(base_channels // 2 ** (i + + 1), k, d)) + self.resblocks = paddle.nn.LayerList() + for i in range(len(self.ups)): + ch = base_channels // 2 ** (i + 1) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, + resblock_dilation_sizes)): + self.resblocks.append(ResBlock(ch, k, d)) + self.conv_post = paddle.nn.Conv1D(ch, istft_params['n_fft'] + 2, 7, 1, padding=3) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.reflection_pad = paddle.nn.Pad1D(padding=(1, 0), mode='reflect') + self.stft_window = paddle.to_tensor(get_window('hann', + istft_params['n_fft'], fftbins=True).astype(np.float32)) + self.f0_predictor = f0_predictor + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + paddle.nn.utils.remove_weight_norm(layer=l) + for l in self.resblocks: + l.remove_weight_norm() + paddle.nn.utils.remove_weight_norm(layer=self.conv_pre) + paddle.nn.utils.remove_weight_norm(layer=self.conv_post) + self.m_source.remove_weight_norm() + for l in self.source_downs: + paddle.nn.utils.remove_weight_norm(layer=l) + for l in self.source_resblocks: + l.remove_weight_norm() + + def _stft(self, x): + spec = paddle.signal.stft(x=x, n_fft=self.istft_params['n_fft'], + hop_length=self.istft_params['hop_len'], win_length=self. + istft_params['n_fft'], window=self.stft_window.to(x.place)) + spec = paddle.as_real(spec) + return spec[..., 0], spec[..., 1] + + def _istft(self, magnitude, phase): + magnitude = paddle.clip(magnitude, max=100.0) + real = magnitude * paddle.cos(phase) + img = magnitude * paddle.sin(phase) + inverse_transform = paddle.signal.istft(x=paddle.complex(real, img), + n_fft=self.istft_params['n_fft'], hop_length=self.istft_params[ + 'hop_len'], win_length=self.istft_params['n_fft'], window=self. + stft_window.to(magnitude.place)) + return inverse_transform + + def decode(self, x: paddle.Tensor, s: paddle.Tensor=paddle.zeros([1, 1, 0]) + ) ->paddle.Tensor: + s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) + s_stft = paddle.cat([s_stft_real, s_stft_imag], dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = paddle.nn.functional.leaky_relu(x=x, negative_slope=self. 
+ lrelu_slope) + x = self.ups[i](x) + if i == self.num_upsamples - 1: + x = self.reflection_pad(x) + si = self.source_downs[i](s_stft) + si = self.source_resblocks[i](si) + x = x + si + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = paddle.nn.functional.leaky_relu(x=x) + x = self.conv_post(x) + magnitude = paddle.exp(x=x[:, :self.istft_params['n_fft'] // 2 + 1, :]) + phase = paddle.sin(x[:, self.istft_params['n_fft'] // 2 + 1:, :]) + x = self._istft(magnitude, phase) + x = paddle.clip(x, -self.audio_limit, self.audio_limit) + return x + + def forward(self, batch: dict) ->Dict[str, + Optional[paddle.Tensor]]: + speech_feat = paddle.transpose(batch['speech_feat'],perm = [0,2,1]).to(device) + f0 = self.f0_predictor(speech_feat) + s = paddle.transpose(self.f0_upsamp(f0[:, None]),perm = [0,2,1]) + s, _, _ = self.m_source(s) + s = paddle.transpose(s,perm = [0,2,1]) + + generated_speech = self.decode(x=speech_feat, s=s) + return generated_speech, f0 + + @paddle.no_grad() + def inference(self, speech_feat: paddle.Tensor, cache_source: paddle. + Tensor=paddle.zeros([1, 1, 0])) ->paddle.Tensor: + paddle.seed(1986) + f0 = self.f0_predictor(speech_feat) + f0_4d = f0[:, None].unsqueeze(2) + s_4d = self.f0_upsamp(f0_4d) + s_3d = s_4d.squeeze(2) + s = paddle.transpose(s_3d, perm=[0, 2, 1]) + + s, _, _ = self.m_source(s) + s = paddle.transpose(s,perm = [0,2,1]) + + if cache_source.shape[2] != 0: + s[:, :, :cache_source.shape[2]] = cache_source + generated_speech = self.decode(x=speech_feat, s=s) + + return generated_speech, s diff --git a/paddlespeech/t2s/models/hifigan/f0_predictor.py b/paddlespeech/t2s/models/hifigan/f0_predictor.py new file mode 100644 index 000000000..f5e922c1a --- /dev/null +++ b/paddlespeech/t2s/models/hifigan/f0_predictor.py @@ -0,0 +1,27 @@ +import paddle +class ConvRNNF0Predictor(paddle.nn.Layer): + + def __init__(self, num_class: int=1, in_channels: int=80, cond_channels: + int=512): + super().__init__() + self.num_class = num_class + self.condnet = paddle.nn.Sequential( + paddle.nn.Conv1D(in_channels, cond_channels, kernel_size=3, padding=1), + paddle.nn.ELU(), + paddle.nn.Conv1D(cond_channels, cond_channels,kernel_size=3, padding=1), + paddle.nn.ELU(), + paddle.nn.Conv1D(cond_channels, cond_channels,kernel_size=3, padding=1), + paddle.nn.ELU(), + paddle.nn.Conv1D(cond_channels, cond_channels,kernel_size=3, padding=1), + paddle.nn.ELU(), + paddle.nn.Conv1D(cond_channels, cond_channels,kernel_size=3, padding=1), + paddle.nn.ELU() + ) + self.classifier = paddle.nn.Linear(in_features=cond_channels, + out_features=self.num_class) + + def forward(self, x: paddle.Tensor) ->paddle.Tensor: + for idx,layer in enumerate(self.condnet): + x = layer(x) + x = paddle.transpose(x, perm=[0, 2, 1]) + return paddle.abs(x=self.classifier(x).squeeze(-1)) diff --git a/paddlespeech/t2s/modules/activation.py b/paddlespeech/t2s/modules/activation.py index f1c099b76..2ff434e3b 100644 --- a/paddlespeech/t2s/modules/activation.py +++ b/paddlespeech/t2s/modules/activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
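The HiFTGenerator.decode() method above predicts an (n_fft/2 + 1)-bin log-magnitude (exponentiated) and a sine-compressed phase from the upsampled features, then resynthesizes the waveform with an inverse STFT. A minimal, self-contained sketch of just that reconstruction step is shown below; it uses the same default istft_params as above (n_fft=16, hop_len=4), the helper name istft_from_mag_phase is illustrative, and the tensors are random placeholders rather than model outputs.

    import numpy as np
    import paddle

    n_fft, hop_len = 16, 4
    # Same Hann window the generator registers as self.stft_window.
    window = paddle.audio.functional.get_window("hann", n_fft, fftbins=True, dtype="float32")

    def istft_from_mag_phase(magnitude, phase):
        # magnitude/phase: [batch, n_fft // 2 + 1, frames] -> waveform [batch, samples]
        real = magnitude * paddle.cos(phase)
        imag = magnitude * paddle.sin(phase)
        spec = paddle.complex(real, imag)
        return paddle.signal.istft(
            spec, n_fft=n_fft, hop_length=hop_len, win_length=n_fft, window=window)

    mag = paddle.rand([1, n_fft // 2 + 1, 20])
    phase = paddle.rand([1, n_fft // 2 + 1, 20]) * 2 * np.pi
    wav = istft_from_mag_phase(mag, phase)  # [1, samples]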
diff --git a/paddlespeech/t2s/modules/conv.py b/paddlespeech/t2s/modules/conv.py
index 922af03f2..e1c3e7015 100644
--- a/paddlespeech/t2s/modules/conv.py
+++ b/paddlespeech/t2s/modules/conv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -124,7 +124,7 @@ class Conv1dCell(nn.Conv1D):
         self._reshaped_weight = paddle.reshape(self.weight,
                                                (self._out_channels, -1))
 
-    def initialize_buffer(self, x_t):
+    def initialize_buffer(self, x_t):
         """Initialize the buffer for the step input.
 
         Args:
diff --git a/paddlespeech/t2s/modules/decoder.py b/paddlespeech/t2s/modules/decoder.py
new file mode 100644
index 000000000..b7fa01fc2
--- /dev/null
+++ b/paddlespeech/t2s/modules/decoder.py
@@ -0,0 +1,16 @@
+import paddle
+
+
+class Transpose(paddle.nn.Layer):
+    """Swap two axes of the input tensor (Paddle analogue of torch.transpose(x, dim0, dim1))."""
+
+    def __init__(self, dim0: int, dim1: int):
+        super().__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        # paddle.transpose needs a full permutation, so build one that swaps dim0 and dim1.
+        perm = list(range(x.ndim))
+        perm[self.dim0], perm[self.dim1] = perm[self.dim1], perm[self.dim0]
+        return paddle.transpose(x, perm=perm)
diff --git a/paddlespeech/t2s/modules/flow/__init__.py b/paddlespeech/t2s/modules/flow/__init__.py
new file mode 100644
index 000000000..2e3d46a07
--- /dev/null
+++ b/paddlespeech/t2s/modules/flow/__init__.py
@@ -0,0 +1 @@
+from .flow import CausalMaskedDiffWithXvec
\ No newline at end of file
diff --git a/paddlespeech/t2s/modules/flow/attention.py b/paddlespeech/t2s/modules/flow/attention.py
new file mode 100644
index 000000000..f4f87c4fc
--- /dev/null
+++ b/paddlespeech/t2s/modules/flow/attention.py
@@ -0,0 +1,387 @@
+import math
+from typing import Any, Dict, Optional
+import paddle
+from .attention_processor import Attention
+
+def _chunked_feed_forward(
+    ff: paddle.nn.Layer, hidden_states: paddle.Tensor, chunk_dim: int, chunk_size: int
+):
+    if hidden_states.shape[chunk_dim] % chunk_size != 0:
+        raise ValueError(
+            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+        )
+    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+    ff_output = paddle.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
+    return ff_output
+
+class SinusoidalPositionalEmbedding(paddle.nn.Layer):
+    """Apply positional information to a sequence of embeddings.
+
+    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+    them.
+
+    Args:
+        embed_dim (int): Dimension of the positional embedding.
+        max_seq_length (int): Maximum sequence length to apply positional embeddings to.
+    """
+
+    def __init__(self, embed_dim: int, max_seq_length: int = 32):
+        super().__init__()
+        position = paddle.arange(max_seq_length).unsqueeze(1)
+        div_term = paddle.exp(
+            x=paddle.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
+        )
+        pe = paddle.zeros([1, max_seq_length, embed_dim])
+        pe[0, :, 0::2] = paddle.sin(position * div_term)
+        pe[0, :, 1::2] = paddle.cos(position * div_term)
+        self.register_buffer(name="pe", tensor=pe)
+
+    def forward(self, x):
+        _, seq_length, _ = x.shape
+        x = x + self.pe[:, :seq_length]
+        return x
+
+class GatedSelfAttentionDense(paddle.nn.Layer):
+    """
+    A gated self-attention dense layer that combines visual features and object features.
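+    The attention and feed-forward contributions are blended into the visual stream
+    through learnable `alpha_attn` / `alpha_dense` gates (passed through tanh), and the
+    whole block can be switched off via `self.enabled`.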
+ + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + self.linear = paddle.nn.Linear(in_features=context_dim, out_features=query_dim) + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + self.norm1 = paddle.nn.LayerNorm(normalized_shape=query_dim) + self.norm2 = paddle.nn.LayerNorm(normalized_shape=query_dim) + self.add_parameter( + name="alpha_attn", + parameter=paddle.nn.parameter.Parameter(paddle.tensor(0.0)), + ) + self.add_parameter( + name="alpha_dense", + parameter=paddle.nn.parameter.Parameter(paddle.tensor(0.0)), + ) + self.enabled = True + + def forward(self, x: paddle.Tensor, objs: paddle.Tensor) -> paddle.Tensor: + if not self.enabled: + return x + n_visual = x.shape[1] + objs = self.linear(objs) + x = ( + x + + self.alpha_attn.tanh() + * self.attn(self.norm1(paddle.cat([x, objs], dim=1)))[:, :n_visual, :] + ) + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + return x + + +class BasicTransformerBlock(paddle.nn.Layer): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + norm_eps: float = 1e-05, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm_zero" + ) + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm" + ) + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + self.norm_type = norm_type + self.num_embeds_ada_norm = num_embeds_ada_norm + if positional_embeddings and num_positional_embeddings is None: + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 
+ ) + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding( + dim, max_seq_length=num_positional_embeddings + ) + else: + self.pos_embed = None + + + self.norm1 = paddle.nn.LayerNorm( + normalized_shape=dim, + weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine, + epsilon=norm_eps, + ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + if cross_attention_dim is not None or double_self_attention: + if norm_type == "ada_norm": + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif norm_type == "ada_norm_continuous": + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = paddle.nn.LayerNorm( + normalized_shape=dim, + epsilon=norm_eps, + weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine, + ) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + else: + self.norm2 = None + self.attn2 = None + if norm_type == "ada_norm_continuous": + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif norm_type in [ + "ada_norm_zero", + "ada_norm", + "layer_norm", + "ada_norm_continuous", + ]: + self.norm3 = paddle.nn.LayerNorm( + normalized_shape=dim, + epsilon=norm_eps, + weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine, + ) + elif norm_type == "layer_norm_i2vgen": + self.norm3 = None + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense( + dim, cross_attention_dim, num_attention_heads, attention_head_dim + ) + if norm_type == "ada_norm_single": + self.scale_shift_table = paddle.nn.parameter.Parameter( + paddle.randn(6, dim) / dim**0.5 + ) + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + timestep: Optional[paddle.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[paddle.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, paddle.Tensor]] = None, + ) -> paddle.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored." 
+ ) + batch_size = hidden_states.shape[0] + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.norm_type == "ada_norm_zero": + (norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm1(hidden_states) + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm1( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif self.norm_type == "ada_norm_single": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.norm_type == "ada_norm_zero": + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.norm_type == "ada_norm_single": + attn_output = gate_msa * attn_output + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + if self.attn2 is not None: + if self.norm_type == "ada_norm": + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: + norm_hidden_states = self.norm2(hidden_states) + elif self.norm_type == "ada_norm_single": + norm_hidden_states = hidden_states + elif self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm2( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + else: + raise ValueError("Incorrect norm") + if self.pos_embed is not None and self.norm_type != "ada_norm_single": + norm_hidden_states = self.pos_embed(norm_hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + if self.norm_type == "ada_norm_continuous": + norm_hidden_states = self.norm3( + hidden_states, added_cond_kwargs["pooled_text_emb"] + ) + elif not self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm3(hidden_states) + if self.norm_type == "ada_norm_zero": + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + if self.norm_type == "ada_norm_single": + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + if self._chunk_size is not None: + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + if self.norm_type == "ada_norm_zero": + ff_output = gate_mlp.unsqueeze(1) * ff_output 
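+        # The adaptive-norm variants modulate the feed-forward branch with a learned
+        # gate (gate_mlp) before it is added back onto the residual stream.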
+ elif self.norm_type == "ada_norm_single": + ff_output = gate_mlp * ff_output + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + return hidden_states + + diff --git a/paddlespeech/t2s/modules/flow/attention_processor.py b/paddlespeech/t2s/modules/flow/attention_processor.py new file mode 100644 index 000000000..f34998083 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/attention_processor.py @@ -0,0 +1,625 @@ +import inspect +import math +from typing import Callable, List, Optional, Union +import paddle + +class Attention(paddle.nn.Layer): + """ + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
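+
+    Example (an illustrative construction only; the sizes are arbitrary and not taken
+    from any config in this repository):
+
+        >>> attn = Attention(query_dim=320, heads=8, dim_head=40)  # inner_dim = 8 * 40 = 320
+        >>> # forward() takes hidden_states of shape [batch, seq_len, query_dim] and is
+        >>> # expected to return a tensor of shape [batch, seq_len, out_dim or query_dim].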
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-05, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self._from_deprecated_attn_block = _from_deprecated_attn_block + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + if norm_num_groups is not None: + self.group_norm = paddle.nn.GroupNorm( + num_channels=query_dim, + num_groups=norm_num_groups, + epsilon=eps, + weight_attr=True, + bias_attr=True, + ) + else: + self.group_norm = None + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps) + self.norm_k = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'" + ) + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = paddle.nn.LayerNorm( + normalized_shape=self.cross_attention_dim + ) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + self.norm_cross = paddle.nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + epsilon=1e-05, + weight_attr=True, + bias_attr=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" + ) + self.to_q = paddle.nn.Linear( + in_features=query_dim, out_features=self.inner_dim, bias_attr=bias + ) + if not self.only_cross_attention: + self.to_k = paddle.nn.Linear( + in_features=self.cross_attention_dim, + out_features=self.inner_dim, + bias_attr=bias, + ) + self.to_v = paddle.nn.Linear( + in_features=self.cross_attention_dim, + out_features=self.inner_dim, + bias_attr=bias, + ) + else: + self.to_k = None + self.to_v = None + if self.added_kv_proj_dim is not None: + self.add_k_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + self.add_v_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + if self.context_pre_only is not None: + self.add_q_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + self.to_out = paddle.nn.LayerList(sublayers=[]) + self.to_out.append( + paddle.nn.Linear( + in_features=self.inner_dim, + out_features=self.out_dim, + bias_attr=out_bias, + ) + ) + self.to_out.append(paddle.nn.Dropout(p=dropout)) + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = paddle.nn.Linear( + in_features=self.inner_dim, + out_features=self.out_dim, + bias_attr=out_bias, + ) + processor = AttnProcessor() + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, paddle.nn.Layer) + and not isinstance(processor, paddle.nn.Layer) + ): + self._modules.pop("processor") + self.processor = processor + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + **cross_attention_kwargs, + ) -> paddle.Tensor: + """ + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k + for k, _ in cross_attention_kwargs.items() + if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: paddle.Tensor) -> paddle.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. 
+ + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim( + self, tensor: paddle.Tensor, out_dim: int = 3 + ) -> paddle.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + return tensor + + def get_attention_scores( + self, + query: paddle.Tensor, + key: paddle.Tensor, + attention_mask: paddle.Tensor = None, + ) -> paddle.Tensor: + """ + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + if attention_mask is None: + baddbmm_input = paddle.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.place, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + attention_scores = paddle.baddbmm( + input=baddbmm_input, + x=query, + y=key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + if self.upcast_softmax: + attention_scores = attention_scores.float() + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + attention_probs = attention_probs.to(dtype) + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: paddle.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> paddle.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. 
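+
+        Note: for `out_dim == 3` the mask is padded to `target_length` and repeated
+        head-wise along the batch axis (matching `head_to_batch_dim`); for
+        `out_dim == 4` a separate head axis is inserted and the mask is repeated along it.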
+ """ + head_size = self.heads + if attention_mask is None: + return attention_mask + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = paddle.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.place, + ) + attention_mask = paddle.cat([attention_mask, padding], dim=2) + else: + attention_mask = paddle.compat.pad( + attention_mask, (0, target_length), value=0.0 + ) + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, axis=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: paddle.Tensor + ) -> paddle.Tensor: + """ + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + if isinstance(self.norm_cross, paddle.nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, paddle.nn.GroupNorm): + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + return encoder_hidden_states + + @paddle.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.place + dtype = self.to_q.weight.data.dtype + if not self.is_cross_attention: + concatenated_weights = paddle.cat( + [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + self.to_qkv = paddle.nn.Linear( + in_features=in_features, + out_features=out_features, + bias_attr=self.use_bias, + ) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = paddle.cat( + [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data] + ) + self.to_qkv.bias.copy_(concatenated_bias) + else: + concatenated_weights = paddle.cat( + [self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + self.to_kv = paddle.nn.Linear( + in_features=in_features, + out_features=out_features, + bias_attr=self.use_bias, + ) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = paddle.cat( + [self.to_k.bias.data, self.to_v.bias.data] + ) + self.to_kv.bias.copy_(concatenated_bias) + self.fused_projections = fuse + + +class AttnProcessor: + """ + Default processor for performing attention-related computations. 
+ """ + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = paddle.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + +class AttnProcessor2_0: + """ + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self): + pass + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + [batch_size, attn.heads, -1, attention_mask.shape[-1]] + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = paddle.transpose(query.view([batch_size, -1, attn.heads, head_dim]),perm = [0,2,1]) + key = paddle.transpose(key.view([batch_size, -1, attn.heads, head_dim]),perm = [0,2,1]) + value =paddle.transpose(value.view([batch_size, -1, attn.heads, head_dim]),perm = [0,2,1]) + hidden_states = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + hidden_states = paddle.transpose(hidden_states,perm = [0,2,1]).reshape( + [batch_size, -1, attn.heads * head_dim] + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states =paddle.transpose(hidden_states,perm = [0,1,3,2]).reshape( + [batch_size, channel, height, width] + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states diff --git a/paddlespeech/t2s/modules/flow/attention_processor_back.py b/paddlespeech/t2s/modules/flow/attention_processor_back.py new file mode 100644 index 000000000..5621efe33 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/attention_processor_back.py @@ -0,0 +1,3015 @@ +import inspect +import math +from importlib import import_module +from typing import Callable, List, Optional, Union + +import paddle + +from ..image_processor import IPAdapterMaskProcessor +from ..utils import deprecate, logging +from ..utils.import_utils import is_torch_npu_available, is_xformers_available +from ..utils.torch_utils import maybe_allow_in_graph +from .lora import LoRALinearLayer + +logger = logging.get_logger(__name__) +if is_torch_npu_available(): + import torch_npu +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@maybe_allow_in_graph +class Attention(paddle.nn.Layer): + """ + A cross attention layer. 
+ + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-05, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + context_pre_only=None, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self._from_deprecated_attn_block = _from_deprecated_attn_block + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + self.heads = out_dim // dim_head if out_dim is not None else heads + self.sliceable_head_dim = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + if norm_num_groups is not None: + self.group_norm = paddle.nn.GroupNorm( + num_channels=query_dim, + num_groups=norm_num_groups, + epsilon=eps, + weight_attr=True, + bias_attr=True, + ) + else: + self.group_norm = None + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps) + self.norm_k = paddle.nn.LayerNorm(normalized_shape=dim_head, epsilon=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'" + ) + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = paddle.nn.LayerNorm( + normalized_shape=self.cross_attention_dim + ) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + self.norm_cross = paddle.nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + epsilon=1e-05, + weight_attr=True, + bias_attr=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. 
Should be None, 'layer_norm' or 'group_norm'" + ) + self.to_q = paddle.nn.Linear( + in_features=query_dim, out_features=self.inner_dim, bias_attr=bias + ) + if not self.only_cross_attention: + self.to_k = paddle.nn.Linear( + in_features=self.cross_attention_dim, + out_features=self.inner_dim, + bias_attr=bias, + ) + self.to_v = paddle.nn.Linear( + in_features=self.cross_attention_dim, + out_features=self.inner_dim, + bias_attr=bias, + ) + else: + self.to_k = None + self.to_v = None + if self.added_kv_proj_dim is not None: + self.add_k_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + self.add_v_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + if self.context_pre_only is not None: + self.add_q_proj = paddle.nn.Linear( + in_features=added_kv_proj_dim, out_features=self.inner_dim + ) + self.to_out = paddle.nn.LayerList(sublayers=[]) + self.to_out.append( + paddle.nn.Linear( + in_features=self.inner_dim, + out_features=self.out_dim, + bias_attr=out_bias, + ) + ) + self.to_out.append(paddle.nn.Dropout(p=dropout)) + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = paddle.nn.Linear( + in_features=self.inner_dim, + out_features=self.out_dim, + bias_attr=out_bias, + ) + if processor is None: + processor = ( + AttnProcessor2_0() +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: + """ + Set whether to use npu flash attention from `torch_npu` or not. + + """ + if use_npu_flash_attention: + processor = AttnProcessorNPU() + else: + processor = ( + AttnProcessor2_0() +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: bool, + attention_op: Optional[Callable] = None, + ) -> None: + """ + Set whether to use memory efficient attention from `xformers` or not. + + Args: + use_memory_efficient_attention_xformers (`bool`): + Whether to use memory efficient attention from `xformers` or not. + attention_op (`Callable`, *optional*): + The attention operation to use. Defaults to `None` which uses the default attention operation from + `xformers`. 
+ """ + is_lora = hasattr(self, "processor") and isinstance( + self.processor, LORA_ATTENTION_PROCESSORS + ) + is_custom_diffusion = hasattr(self, "processor") and isinstance( + self.processor, + ( + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + CustomDiffusionAttnProcessor2_0, + ), + ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) + if use_memory_efficient_attention_xformers: + if is_added_kv_processor and (is_lora or is_custom_diffusion): + raise NotImplementedError( + f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}" + ) + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers", + name="xformers", + ) + elif not paddle.device.cuda.device_count() >= 1: + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only available for GPU " + ) + else: + try: + _ = xformers.ops.memory_efficient_attention( + paddle.randn((1, 2, 40), device="cuda"), + paddle.randn((1, 2, 40), device="cuda"), + paddle.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + if is_lora: + processor = LoRAXFormersAttnProcessor( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + attention_op=attention_op, + ) + processor.set_state_dict(state_dict=self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.place) + elif is_custom_diffusion: + processor = CustomDiffusionXFormersAttnProcessor( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + attention_op=attention_op, + ) + processor.set_state_dict(state_dict=self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.place) + elif is_added_kv_processor: + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." 
+ ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) + else: + processor = XFormersAttnProcessor(attention_op=attention_op) + elif is_lora: + attn_processor_class = ( + LoRAAttnProcessor2_0 +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + else LoRAAttnProcessor + ) + processor = attn_processor_class( + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + rank=self.processor.rank, + ) + processor.set_state_dict(state_dict=self.processor.state_dict()) + processor.to(self.processor.to_q_lora.up.weight.place) + elif is_custom_diffusion: + attn_processor_class = ( + CustomDiffusionAttnProcessor2_0 +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + else CustomDiffusionAttnProcessor + ) + processor = attn_processor_class( + train_kv=self.processor.train_kv, + train_q_out=self.processor.train_q_out, + hidden_size=self.processor.hidden_size, + cross_attention_dim=self.processor.cross_attention_dim, + ) + processor.set_state_dict(state_dict=self.processor.state_dict()) + if hasattr(self.processor, "to_k_custom_diffusion"): + processor.to(self.processor.to_k_custom_diffusion.weight.place) + else: + processor = ( + AttnProcessor2_0() +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_attention_slice(self, slice_size: int) -> None: + """ + Set the slice size for attention computation. + + Args: + slice_size (`int`): + The slice size for attention computation. + """ + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." + ) + if slice_size is not None and self.added_kv_proj_dim is not None: + processor = SlicedAttnAddedKVProcessor(slice_size) + elif slice_size is not None: + processor = SlicedAttnProcessor(slice_size) + elif self.added_kv_proj_dim is not None: + processor = AttnAddedKVProcessor() + else: + processor = ( + AttnProcessor2_0() +>>>>>> if hasattr(torch.nn.functional, "scaled_dot_product_attention") + and self.scale_qk + else AttnProcessor() + ) + self.set_processor(processor) + + def set_processor(self, processor: "AttnProcessor") -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + if ( + hasattr(self, "processor") + and isinstance(self.processor, paddle.nn.Layer) + and not isinstance(processor, paddle.nn.Layer) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + self.processor = processor + + def get_processor( + self, return_deprecated_lora: bool = False + ) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. 
+ """ + if not return_deprecated_lora: + return self.processor + is_lora_activated = { + name: (module.lora_layer is not None) + for name, module in self.named_sublayers(include_self=True) + if hasattr(module, "lora_layer") + } + if not any(is_lora_activated.values()): + return self.processor + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr( + import_module(__name__), "LoRA" + non_lora_processor_cls_name + ) + hidden_size = self.inner_dim + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict( + self.add_k_proj.lora_layer.state_dict() + ) + lora_processor.add_v_proj_lora.load_state_dict( + self.add_v_proj.lora_layer.state_dict() + ) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + return lora_processor + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + **cross_attention_kwargs, + ) -> paddle.Tensor: + """ + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. 
+ + Returns: + `torch.Tensor`: The output of the attention layer. + """ + print("processor:", self.processor, "2" * 200) + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + quiet_attn_parameters = {"ip_adapter_masks"} + unused_kwargs = [ + k + for k, _ in cross_attention_kwargs.items() + if k not in attn_parameters and k not in quiet_attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: paddle.Tensor) -> paddle.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim( + self, tensor: paddle.Tensor, out_dim: int = 3 + ) -> paddle.Tensor: + """ + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + return tensor + + def get_attention_scores( + self, + query: paddle.Tensor, + key: paddle.Tensor, + attention_mask: paddle.Tensor = None, + ) -> paddle.Tensor: + """ + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. 
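+
+        Example (illustrative sketch, not executed by this module; it assumes `query`
+        and `key` were already reshaped with `head_to_batch_dim`, so their shapes are
+        `[batch * heads, q_len, head_dim]` and `[batch * heads, k_len, head_dim]`)::
+
+            probs = attn.get_attention_scores(query, key, attention_mask)
+            # probs: [batch * heads, q_len, k_len]; each row sums to 1 after softmax
+            out = paddle.bmm(probs, value)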
+ """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + if attention_mask is None: + baddbmm_input = paddle.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.place, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + attention_scores = paddle.baddbmm( + input=baddbmm_input, + x=query, + y=key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + if self.upcast_softmax: + attention_scores = attention_scores.float() + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + attention_probs = attention_probs.to(dtype) + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: paddle.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> paddle.Tensor: + """ + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = paddle.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.place, + ) + attention_mask = paddle.cat([attention_mask, padding], dim=2) + else: + attention_mask = paddle.compat.pad( + attention_mask, (0, target_length), value=0.0 + ) + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: paddle.Tensor + ) -> paddle.Tensor: + """ + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. 
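+
+        Example (sketch; assumes the layer was constructed with a cross-attention
+        norm so that `self.norm_cross` is a `LayerNorm` or `GroupNorm`)::
+
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            # shape is unchanged: [batch, seq_len, cross_attention_dim]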
+ """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + if isinstance(self.norm_cross, paddle.nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, paddle.nn.GroupNorm): + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + return encoder_hidden_states + + @paddle.no_grad() + def fuse_projections(self, fuse=True): + device = self.to_q.weight.data.place + dtype = self.to_q.weight.data.dtype + if not self.is_cross_attention: + concatenated_weights = paddle.cat( + [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + self.to_qkv = paddle.nn.Linear( + in_features=in_features, + out_features=out_features, + bias_attr=self.use_bias, + ) + self.to_qkv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = paddle.cat( + [self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data] + ) + self.to_qkv.bias.copy_(concatenated_bias) + else: + concatenated_weights = paddle.cat( + [self.to_k.weight.data, self.to_v.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + self.to_kv = paddle.nn.Linear( + in_features=in_features, + out_features=out_features, + bias_attr=self.use_bias, + ) + self.to_kv.weight.copy_(concatenated_weights) + if self.use_bias: + concatenated_bias = paddle.cat( + [self.to_k.bias.data, self.to_v.bias.data] + ) + self.to_kv.bias.copy_(concatenated_bias) + self.fused_projections = fuse + + +class AttnProcessor: + """ + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
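+        # NOTE: the remainder of this call implements the plain attention path:
+        # Q/K/V projections, `head_to_batch_dim` reshapes, an explicit softmax of the
+        # scaled scores via `get_attention_scores`, a `paddle.bmm` with the value
+        # tensor, and the output projection in `attn.to_out`. Memory grows with
+        # q_len * k_len, which is why the 2.0 / xFormers / sliced processors below
+        # exist as drop-in alternatives.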
+ deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = paddle.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class CustomDiffusionAttnProcessor(paddle.nn.Layer): + """ + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. 
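+
+    Example (illustrative sketch; `hidden_size` and `cross_attention_dim` below are
+    placeholder values, not taken from a real checkpoint)::
+
+        processor = CustomDiffusionAttnProcessor(
+            train_kv=True,
+            train_q_out=False,
+            hidden_size=320,
+            cross_attention_dim=768,
+        )
+        attn.set_processor(processor)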
+ """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + if self.train_kv: + self.to_k_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + self.to_v_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + if self.train_q_out: + self.to_q_custom_diffusion = paddle.nn.Linear( + in_features=hidden_size, out_features=hidden_size, bias_attr=False + ) + self.to_out_custom_diffusion = paddle.nn.LayerList(sublayers=[]) + self.to_out_custom_diffusion.append( + paddle.nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias_attr=out_bias, + ) + ) + self.to_out_custom_diffusion.append(paddle.nn.Dropout(p=dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + if self.train_kv: + key = self.to_k_custom_diffusion( + encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype) + ) + value = self.to_v_custom_diffusion( + encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype) + ) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + if crossattn: + detach = paddle.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = paddle.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + if self.train_q_out: + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AttnAddedKVProcessor: + """ + Processor for performing attention-related computations with extra learnable key and value matrices for the text + encoder. 
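+
+    Example (sketch; only meaningful for `Attention` layers created with
+    `added_kv_proj_dim`, i.e. layers that own `add_k_proj`/`add_v_proj`)::
+
+        attn.set_processor(AttnAddedKVProcessor())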
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.Tensor,
+        encoder_hidden_states: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> paddle.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
+        hidden_states = hidden_states.view(
+            hidden_states.shape[0], hidden_states.shape[1], -1
+        ).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(
+            encoder_hidden_states_key_proj
+        )
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(
+            encoder_hidden_states_value_proj
+        )
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = paddle.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = paddle.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = paddle.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+    """
+    Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+    learnable key and value matrices for the text encoder.
+    """
+
+    def __init__(self):
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnAddedKVProcessor2_0 requires `paddle.nn.functional.scaled_dot_product_attention`; please upgrade PaddlePaddle to a version that provides it."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.Tensor,
+        encoder_hidden_states: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> paddle.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
+        hidden_states = hidden_states.view(
+            hidden_states.shape[0], hidden_states.shape[1], -1
+        ).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size, out_dim=4
+        )
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query, out_dim=4)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(
+            encoder_hidden_states_key_proj, out_dim=4
+        )
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(
+            encoder_hidden_states_value_proj, out_dim=4
+        )
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key, out_dim=4)
+            value = attn.head_to_batch_dim(value, out_dim=4)
+            key = paddle.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = paddle.cat([encoder_hidden_states_value_proj, value], dim=2)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+        hidden_states = paddle.nn.functional.scaled_dot_product_attention(
+            query.transpose([0, 2, 1, 3]),
+            key.transpose([0, 2, 1, 3]),
+            value.transpose([0, 2, 1, 3]),
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+        ).transpose([0, 2, 1, 3])
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, residual.shape[1]
+        )
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class JointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "JointAttnProcessor2_0 requires `paddle.nn.functional.scaled_dot_product_attention`; please upgrade PaddlePaddle to a version that provides it."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.FloatTensor,
+        encoder_hidden_states: paddle.FloatTensor = None,
+        attention_mask: Optional[paddle.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> paddle.FloatTensor:
+        residual = hidden_states
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+        batch_size = encoder_hidden_states.shape[0]
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        query = paddle.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = paddle.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = paddle.cat([value, encoder_hidden_states_value_proj], dim=1)
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        hidden_states = paddle.nn.functional.scaled_dot_product_attention(
+            query.transpose([0, 2, 1, 3]),
+            key.transpose([0, 2, 1, 3]),
+            value.transpose([0, 2, 1, 3]),
+            dropout_p=0.0,
+            is_causal=False,
+        ).transpose([0, 2, 1, 3])
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+        return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedJointAttnProcessor2_0 requires `paddle.nn.functional.scaled_dot_product_attention`; please upgrade PaddlePaddle to a version that provides it."
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.FloatTensor, + encoder_hidden_states: paddle.FloatTensor = None, + attention_mask: Optional[paddle.FloatTensor] = None, + *args, + **kwargs, + ) -> paddle.FloatTensor: + residual = hidden_states + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + context_input_ndim = encoder_hidden_states.ndim + if context_input_ndim == 4: + batch_size, channel, height, width = encoder_hidden_states.shape + encoder_hidden_states = encoder_hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size = encoder_hidden_states.shape[0] + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = paddle.compat.split(qkv, split_size, dim=-1) + encoder_qkv = attn.to_added_qkv(encoder_hidden_states) + split_size = encoder_qkv.shape[-1] // 3 + ( + encoder_hidden_states_query_proj, + encoder_hidden_states_key_proj, + encoder_hidden_states_value_proj, + ) = paddle.compat.split(encoder_qkv, split_size, dim=-1) + query = paddle.cat([query, encoder_hidden_states_query_proj], dim=1) + key = paddle.cat([key, encoder_hidden_states_key_proj], dim=1) + value = paddle.cat([value, encoder_hidden_states_value_proj], dim=1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = ( + hidden_states + ) = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + dropout_p=0.0, + is_causal=False, + ).transpose( + [0, 2, 1, 3] + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if context_input_ndim == 4: + encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + return hidden_states, encoder_hidden_states + + +class XFormersAttnAddedKVProcessor: + """ + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. 
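+
+    Example (sketch; requires the optional `xformers` package to be installed)::
+
+        # `attention_op=None` lets xFormers pick the best available kernel.
+        attn.set_processor(XFormersAttnAddedKVProcessor(attention_op=None))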
+ """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + residual = hidden_states + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = paddle.cat([encoder_hidden_states_key_proj, key], dim=1) + value = paddle.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return hidden_states + + +class XFormersAttnProcessor: + """ + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
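+        # NOTE: this path mirrors `AttnProcessor`, but the score/softmax/matmul steps are
+        # replaced by a single `xformers.ops.memory_efficient_attention` call, so the full
+        # [q_len, k_len] attention matrix is never materialized.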
+ deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, key_tokens, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, key_tokens, batch_size + ) + if attention_mask is not None: + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class AttnProcessorNPU: + """ + Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If + fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is + not significant. + + """ + + def __init__(self): + if not is_torch_npu_available(): + raise ImportError( + "AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices." + ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
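+        # NOTE: fp16/bf16 inputs are routed to `torch_npu.npu_fusion_attention` below;
+        # fp32 falls back to `paddle.nn.functional.scaled_dot_product_attention`, which is
+        # correct but, as the class docstring notes, gains little acceleration on NPU.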
+ deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if query.dtype in (paddle.float16, paddle.bfloat16): + hidden_states = torch_npu.npu_fusion_attention( + query, + key, + value, + attn.heads, + input_layout="BNSD", + pse=None, + atten_mask=attention_mask, + scale=1.0 / math.sqrt(query.shape[-1]), + pre_tockens=65536, + next_tockens=65536, + keep_prob=1.0, + sync=False, + inner_precise=0, + )[0] + else: + hidden_states = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class AttnProcessor2_0: + """ + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
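+
+    Example (sketch; `scaled_dot_product_attention` is available in recent Paddle
+    releases, otherwise fall back to the plain `AttnProcessor`)::
+
+        if hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            attn.set_processor(AttnProcessor2_0())
+        else:
+            attn.set_processor(AttnProcessor())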
+    """
+
+    def __init__(self):
+        return
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.Tensor,
+        encoder_hidden_states: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        temb: Optional[paddle.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> paddle.Tensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+        query = attn.to_q(hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        # [batch, seq_len, heads, head_dim] -> [batch, heads, seq_len, head_dim]; the
+        # transposes in the scaled_dot_product_attention call below restore the
+        # [batch, seq_len, heads, head_dim] layout that Paddle's kernel expects.
+        query = paddle.transpose(
+            query.view(batch_size, -1, attn.heads, head_dim), perm=[0, 2, 1, 3]
+        )
+        key = paddle.transpose(
+            key.view(batch_size, -1, attn.heads, head_dim), perm=[0, 2, 1, 3]
+        )
+        value = paddle.transpose(
+            value.view(batch_size, -1, attn.heads, head_dim), perm=[0, 2, 1, 3]
+        )
+        hidden_states = paddle.nn.functional.scaled_dot_product_attention(
+            query.transpose([0, 2, 1, 3]),
+            key.transpose([0, 2, 1, 3]),
+            value.transpose([0, 2, 1, 3]),
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+        ).transpose([0, 2, 1, 3])
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+        hidden_states = hidden_states / attn.rescale_output_factor
+        return hidden_states
+
+
+class HunyuanAttnProcessor2_0:
+    """
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+    """
+
+    def __init__(self):
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "HunyuanAttnProcessor2_0 requires `paddle.nn.functional.scaled_dot_product_attention`; please upgrade PaddlePaddle to a version that provides it."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.Tensor,
+        encoder_hidden_states: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        temb: Optional[paddle.Tensor] = None,
+        image_rotary_emb: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+        query = attn.to_q(hidden_states)
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+        hidden_states = paddle.nn.functional.scaled_dot_product_attention(
+            query.transpose([0, 2, 1, 3]),
+            key.transpose([0, 2, 1, 3]),
+            value.transpose([0, 2, 1, 3]),
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+        ).transpose([0, 2, 1, 3])
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+        hidden_states = hidden_states / attn.rescale_output_factor
+        return hidden_states
+
+
+class FusedAttnProcessor2_0:
+    """
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+    fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+    For cross-attention modules, key and value projection matrices are fused.
+
+
+    This API is currently 🧪 experimental in nature and can change in future.
+
+
+    """
+
+    def __init__(self):
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires `paddle.nn.functional.scaled_dot_product_attention`; please upgrade PaddlePaddle to a version that provides it."
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + *args, + **kwargs, + ) -> paddle.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = paddle.compat.split(qkv, split_size, dim=-1) + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + query = attn.to_q(hidden_states) + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = paddle.compat.split(kv, split_size, dim=-1) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class CustomDiffusionXFormersAttnProcessor(paddle.nn.Layer): + """ + Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. 
+ cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use + as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = False, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + attention_op: Optional[Callable] = None, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.attention_op = attention_op + if self.train_kv: + self.to_k_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + self.to_v_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + if self.train_q_out: + self.to_q_custom_diffusion = paddle.nn.Linear( + in_features=hidden_size, out_features=hidden_size, bias_attr=False + ) + self.to_out_custom_diffusion = paddle.nn.LayerList(sublayers=[]) + self.to_out_custom_diffusion.append( + paddle.nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias_attr=out_bias, + ) + ) + self.to_out_custom_diffusion.append(paddle.nn.Dropout(p=dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) + else: + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + if self.train_kv: + key = self.to_k_custom_diffusion( + encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype) + ) + value = self.to_v_custom_diffusion( + encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype) + ) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + if crossattn: + detach = paddle.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, 
+ scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + if self.train_q_out: + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class CustomDiffusionAttnProcessor2_0(paddle.nn.Layer): + """ + Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled + dot-product attention. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + + def __init__( + self, + train_kv: bool = True, + train_q_out: bool = True, + hidden_size: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + out_bias: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.train_kv = train_kv + self.train_q_out = train_q_out + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + if self.train_kv: + self.to_k_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + self.to_v_custom_diffusion = paddle.nn.Linear( + in_features=cross_attention_dim or hidden_size, + out_features=hidden_size, + bias_attr=False, + ) + if self.train_q_out: + self.to_q_custom_diffusion = paddle.nn.Linear( + in_features=hidden_size, out_features=hidden_size, bias_attr=False + ) + self.to_out_custom_diffusion = paddle.nn.LayerList(sublayers=[]) + self.to_out_custom_diffusion.append( + paddle.nn.Linear( + in_features=hidden_size, + out_features=hidden_size, + bias_attr=out_bias, + ) + ) + self.to_out_custom_diffusion.append(paddle.nn.Dropout(p=dropout)) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if self.train_q_out: + query = self.to_q_custom_diffusion(hidden_states) + else: + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + crossattn = False + encoder_hidden_states = hidden_states + else: + crossattn = True + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + if self.train_kv: + key = self.to_k_custom_diffusion( + encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype) + ) + value = self.to_v_custom_diffusion( + encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype) + ) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) + else: + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + if crossattn: + 
detach = paddle.ones_like(key) + detach[:, :1, :] = detach[:, :1, :] * 0.0 + key = detach * key + (1 - detach) * key.detach() + value = detach * value + (1 - detach) * value.detach() + inner_dim = hidden_states.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + if self.train_q_out: + hidden_states = self.to_out_custom_diffusion[0](hidden_states) + hidden_states = self.to_out_custom_diffusion[1](hidden_states) + else: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class SlicedAttnProcessor: + """ + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size: int): + self.slice_size = slice_size + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + residual = hidden_states + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + batch_size_attention, query_tokens, _ = query.shape + hidden_states = paddle.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.place, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = paddle.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + 
hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class SlicedAttnAddedKVProcessor: + """ + Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + + def __init__(self, slice_size): + self.slice_size = slice_size + + def __call__( + self, + attn: "Attention", + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + hidden_states = hidden_states.view( + hidden_states.shape[0], hidden_states.shape[1], -1 + ).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + dim = query.shape[-1] + query = attn.head_to_batch_dim(query) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = attn.head_to_batch_dim( + encoder_hidden_states_value_proj + ) + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = paddle.cat([encoder_hidden_states_key_proj, key], dim=1) + value = paddle.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + batch_size_attention, query_tokens, _ = query.shape + hidden_states = paddle.zeros( + (batch_size_attention, query_tokens, dim // attn.heads), + device=query.place, + dtype=query.dtype, + ) + for i in range(batch_size_attention // self.slice_size): + start_idx = i * self.slice_size + end_idx = (i + 1) * self.slice_size + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + attn_mask_slice = ( + attention_mask[start_idx:end_idx] + if attention_mask is not None + else None + ) + attn_slice = attn.get_attention_scores( + query_slice, key_slice, attn_mask_slice + ) + attn_slice = paddle.bmm(attn_slice, value[start_idx:end_idx]) + hidden_states[start_idx:end_idx] = attn_slice + hidden_states = attn.batch_to_head_dim(hidden_states) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + return 
hidden_states + + +class SpatialNorm(paddle.nn.Layer): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + """ + + def __init__(self, f_channels: int, zq_channels: int): + super().__init__() + self.norm_layer = paddle.nn.GroupNorm( + num_channels=f_channels, + num_groups=32, + epsilon=1e-06, + weight_attr=True, + bias_attr=True, + ) + self.conv_y = paddle.nn.Conv2d( + zq_channels, f_channels, kernel_size=1, stride=1, padding=0 + ) + self.conv_b = paddle.nn.Conv2d( + zq_channels, f_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, f: paddle.Tensor, zq: paddle.Tensor) -> paddle.Tensor: + f_size = f.shape[-2:] + zq = paddle.nn.functional.interpolate(x=zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class LoRAAttnProcessor(paddle.nn.Layer): + def __init__( + self, + hidden_size: int, + cross_attention_dim: Optional[int] = None, + rank: int = 4, + network_alpha: Optional[int] = None, + **kwargs, + ): + deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`." + deprecate( + "LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False + ) + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = ( + out_hidden_size if out_hidden_size is not None else hidden_size + ) + self.to_q_lora = LoRALinearLayer( + q_hidden_size, q_hidden_size, q_rank, network_alpha + ) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank, network_alpha + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha + ) + self.to_out_lora = LoRALinearLayer( + out_hidden_size, out_hidden_size, out_rank, network_alpha + ) + + def __call__( + self, attn: Attention, hidden_states: paddle.Tensor, **kwargs + ) -> paddle.Tensor: + self_cls_name = self.__class__.__name__ + deprecate( + self_cls_name, + "0.26.0", + f"Make sure use {self_cls_name[4:]} instead by settingLoRA layers to `self.{{to_q,to_k,to_v,to_out[0]}}.lora_layer` respectively. 
This will be done automatically when using `LoraLoaderMixin.load_lora_weights`",
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.place)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.place)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.place)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.place)
+        attn._sub_layers.pop("processor", None)
+        attn.processor = AttnProcessor()
+        return attn.processor(attn, hidden_states, **kwargs)
+
+
+class LoRAAttnProcessor2_0(paddle.nn.Layer):
+    def __init__(
+        self,
+        hidden_size: int,
+        cross_attention_dim: Optional[int] = None,
+        rank: int = 4,
+        network_alpha: Optional[int] = None,
+        **kwargs,
+    ):
+        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
+        deprecate(
+            "LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False
+        )
+        super().__init__()
+        if not hasattr(paddle.nn.functional, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnProcessor2_0 requires paddle.nn.functional.scaled_dot_product_attention; please upgrade PaddlePaddle."
+            )
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.rank = rank
+        q_rank = kwargs.pop("q_rank", None)
+        q_hidden_size = kwargs.pop("q_hidden_size", None)
+        q_rank = q_rank if q_rank is not None else rank
+        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
+        v_rank = kwargs.pop("v_rank", None)
+        v_hidden_size = kwargs.pop("v_hidden_size", None)
+        v_rank = v_rank if v_rank is not None else rank
+        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
+        out_rank = kwargs.pop("out_rank", None)
+        out_hidden_size = kwargs.pop("out_hidden_size", None)
+        out_rank = out_rank if out_rank is not None else rank
+        out_hidden_size = (
+            out_hidden_size if out_hidden_size is not None else hidden_size
+        )
+        self.to_q_lora = LoRALinearLayer(
+            q_hidden_size, q_hidden_size, q_rank, network_alpha
+        )
+        self.to_k_lora = LoRALinearLayer(
+            cross_attention_dim or hidden_size, hidden_size, rank, network_alpha
+        )
+        self.to_v_lora = LoRALinearLayer(
+            cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha
+        )
+        self.to_out_lora = LoRALinearLayer(
+            out_hidden_size, out_hidden_size, out_rank, network_alpha
+        )
+
+    def __call__(
+        self, attn: Attention, hidden_states: paddle.Tensor, **kwargs
+    ) -> paddle.Tensor:
+        self_cls_name = self.__class__.__name__
+        deprecate(
+            self_cls_name,
+            "0.26.0",
+            f"Make sure to use {self_cls_name[4:]} instead by setting LoRA layers to `self.{{to_q,to_k,to_v,to_out[0]}}.lora_layer` respectively. This will be done automatically when using `LoraLoaderMixin.load_lora_weights`",
+        )
+        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.place)
+        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.place)
+        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.place)
+        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.place)
+        attn._sub_layers.pop("processor", None)
+        attn.processor = AttnProcessor2_0()
+        return attn.processor(attn, hidden_states, **kwargs)
+
+
+class LoRAXFormersAttnProcessor(paddle.nn.Layer):
+    """
+    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
+
+    Args:
+        hidden_size (`int`, *optional*):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the `encoder_hidden_states`.
+ rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. + """ + + def __init__( + self, + hidden_size: int, + cross_attention_dim: int, + rank: int = 4, + attention_op: Optional[Callable] = None, + network_alpha: Optional[int] = None, + **kwargs, + ): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.attention_op = attention_op + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = ( + out_hidden_size if out_hidden_size is not None else hidden_size + ) + self.to_q_lora = LoRALinearLayer( + q_hidden_size, q_hidden_size, q_rank, network_alpha + ) + self.to_k_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank, network_alpha + ) + self.to_v_lora = LoRALinearLayer( + cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha + ) + self.to_out_lora = LoRALinearLayer( + out_hidden_size, out_hidden_size, out_rank, network_alpha + ) + + def __call__( + self, attn: Attention, hidden_states: paddle.Tensor, **kwargs + ) -> paddle.Tensor: + self_cls_name = self.__class__.__name__ + deprecate( + self_cls_name, + "0.26.0", + f"Make sure use {self_cls_name[4:]} instead by settingLoRA layers to `self.{{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}}.lora_layer` respectively. This will be done automatically when using `LoraLoaderMixin.load_lora_weights`", + ) + attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.place) + attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.place) + attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.place) + attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.place) + attn._modules.pop("processor") + attn.processor = XFormersAttnProcessor() + return attn.processor(attn, hidden_states, **kwargs) + + +class LoRAAttnAddedKVProcessor(paddle.nn.Layer): + """ + Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text + encoder. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + kwargs (`dict`): + Additional keyword arguments to pass to the `LoRALinearLayer` layers. 
+ """ + + def __init__( + self, + hidden_size: int, + cross_attention_dim: Optional[int] = None, + rank: int = 4, + network_alpha: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.add_k_proj_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank, network_alpha + ) + self.add_v_proj_lora = LoRALinearLayer( + cross_attention_dim or hidden_size, hidden_size, rank, network_alpha + ) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer( + hidden_size, hidden_size, rank, network_alpha + ) + + def __call__( + self, attn: Attention, hidden_states: paddle.Tensor, **kwargs + ) -> paddle.Tensor: + self_cls_name = self.__class__.__name__ + deprecate( + self_cls_name, + "0.26.0", + f"Make sure use {self_cls_name[4:]} instead by settingLoRA layers to `self.{{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}}.lora_layer` respectively. This will be done automatically when using `LoraLoaderMixin.load_lora_weights`", + ) + attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.place) + attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.place) + attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.place) + attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.place) + attn._modules.pop("processor") + attn.processor = AttnAddedKVProcessor() + return attn.processor(attn, hidden_states, **kwargs) + + +class IPAdapterAttnProcessor(paddle.nn.Layer): + """ + Attention processor for Multiple IP-Adapters. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or List[`float`], defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__( + self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0 + ): + super().__init__() + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError( + "`scale` should be a list of integers with the same length as `num_tokens`." 
+ ) + self.scale = scale + self.to_k_ip = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Linear( + in_features=cross_attention_dim, + out_features=hidden_size, + bias_attr=False, + ) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Linear( + in_features=cross_attention_dim, + out_features=hidden_size, + bias_attr=False, + ) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[paddle.Tensor] = None, + ): + residual = hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release. Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + deprecate( + "encoder_hidden_states not a tuple", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = paddle.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match length of self.scale array ({len(self.scale)}) and number of ip_hidden_states ({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate( + zip(ip_adapter_masks, self.scale, ip_hidden_states) + ): + if not isinstance(mask, paddle.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape [1, num_images_for_ip_adapter, height, width]. 
Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + ip_attention_probs = attn.get_attention_scores( + query, ip_key, None + ) + _current_ip_hidden_states = paddle.bmm( + ip_attention_probs, ip_value + ) + _current_ip_hidden_states = attn.batch_to_head_dim( + _current_ip_hidden_states + ) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to( + dtype=query.dtype, device=query.place + ) + hidden_states = hidden_states + scale[i] * ( + _current_ip_hidden_states * mask_downsample + ) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = paddle.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim( + current_ip_hidden_states + ) + hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +class IPAdapterAttnProcessor2_0(paddle.nn.Layer): + """ + Attention processor for IP-Adapter for PyTorch 2.0. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`): + The context length of the image features. + scale (`float` or `List[float]`, defaults to 1.0): + the weight scale of image prompt. + """ + + def __init__( + self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0 + ): + super().__init__() +>>>>>> if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + if not isinstance(num_tokens, (tuple, list)): + num_tokens = [num_tokens] + self.num_tokens = num_tokens + if not isinstance(scale, list): + scale = [scale] * len(num_tokens) + if len(scale) != len(num_tokens): + raise ValueError( + "`scale` should be a list of integers with the same length as `num_tokens`." + ) + self.scale = scale + self.to_k_ip = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Linear( + in_features=cross_attention_dim, + out_features=hidden_size, + bias_attr=False, + ) + for _ in range(len(num_tokens)) + ] + ) + self.to_v_ip = paddle.nn.LayerList( + sublayers=[ + paddle.nn.Linear( + in_features=cross_attention_dim, + out_features=hidden_size, + bias_attr=False, + ) + for _ in range(len(num_tokens)) + ] + ) + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + temb: Optional[paddle.Tensor] = None, + scale: float = 1.0, + ip_adapter_masks: Optional[paddle.Tensor] = None, + ): + residual = hidden_states + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, tuple): + encoder_hidden_states, ip_hidden_states = encoder_hidden_states + else: + deprecation_message = "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release. Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning." + deprecate( + "encoder_hidden_states not a tuple", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + [encoder_hidden_states[:, end_pos:, :]], + ) + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + input_ndim = hidden_states.ndim + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + query = attn.to_q(hidden_states) + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + hidden_states = paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + hidden_states = hidden_states.transpose(1, 
2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + if ip_adapter_masks is not None: + if not isinstance(ip_adapter_masks, List): + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states): + raise ValueError( + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match length of self.scale array ({len(self.scale)}) and number of ip_hidden_states ({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate( + zip(ip_adapter_masks, self.scale, ip_hidden_states) + ): + if not isinstance(mask, paddle.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape [1, num_images_for_ip_adapter, height, width]. Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match number of scales ({len(scale)}) at index {index}" + ) + else: + ip_adapter_masks = [None] * len(self.scale) + for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( + ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks + ): + skip = False + if isinstance(scale, list): + if all(s == 0 for s in scale): + skip = True + elif scale == 0: + skip = True + if not skip: + if mask is not None: + if not isinstance(scale, list): + scale = [scale] * mask.shape[1] + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + ip_key = ip_key.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + ip_value = ip_value.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + _current_ip_hidden_states = ( + paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + ip_key.transpose([0, 2, 1, 3]), + ip_value.transpose([0, 2, 1, 3]), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + ) + _current_ip_hidden_states = _current_ip_hidden_states.transpose( + 1, 2 + ).reshape(batch_size, -1, attn.heads * head_dim) + _current_ip_hidden_states = _current_ip_hidden_states.to( + query.dtype + ) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) + mask_downsample = mask_downsample.to( + dtype=query.dtype, device=query.place + ) + hidden_states = hidden_states + scale[i] * ( + _current_ip_hidden_states * mask_downsample + ) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) + ip_key = ip_key.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + ip_value = ip_value.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + current_ip_hidden_states = ( + paddle.nn.functional.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + ip_key.transpose([0, 2, 1, 3]), + ip_value.transpose([0, 2, 1, 3]), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + ).transpose([0, 2, 1, 3]) + ) + current_ip_hidden_states = current_ip_hidden_states.transpose( + 1, 2 + ).reshape(batch_size, -1, 
attn.heads * head_dim) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if attn.residual_connection: + hidden_states = hidden_states + residual + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states + + +LORA_ATTENTION_PROCESSORS = ( + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, +) +ADDED_KV_ATTENTION_PROCESSORS = ( + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, +) +CROSS_ATTENTION_PROCESSORS = ( + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +AttentionProcessor = Union[ + AttnProcessor, + AttnProcessor2_0, + FusedAttnProcessor2_0, + XFormersAttnProcessor, + SlicedAttnProcessor, + AttnAddedKVProcessor, + SlicedAttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, + CustomDiffusionAttnProcessor, + CustomDiffusionXFormersAttnProcessor, + CustomDiffusionAttnProcessor2_0, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, +] diff --git a/paddlespeech/t2s/modules/flow/convolution.py b/paddlespeech/t2s/modules/flow/convolution.py new file mode 100644 index 000000000..9f38479d8 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/convolution.py @@ -0,0 +1,99 @@ +import paddle + +"""ConvolutionModule definition.""" +from typing import Tuple + + +class ConvolutionModule(paddle.nn.Layer): + """ConvolutionModule in Conformer model.""" + + def __init__( + self, + channels: int, + kernel_size: int = 15, + activation: paddle.nn.Layer = paddle.nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + ): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + self.pointwise_conv1 = paddle.nn.Conv1d( + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias + ) + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = paddle.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + assert norm in ["batch_norm", "layer_norm"] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = paddle.nn.BatchNorm1D(num_features=channels) + else: + self.use_layer_norm = True + self.norm = paddle.nn.LayerNorm(normalized_shape=channels) + self.pointwise_conv2 = paddle.nn.Conv1d( + channels, channels, kernel_size=1, stride=1, padding=0, bias=bias + ) + self.activation = activation + + def forward( + self, + x: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + cache: paddle.Tensor = paddle.zeros((0, 0, 0)), + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). 
+            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
+            cache (torch.Tensor): left context cache, it is only
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+        """
+        x = x.transpose([0, 2, 1])
+        if mask_pad.shape[2] > 0:
+            x = x * mask_pad.astype(x.dtype)
+        if self.lorder > 0:
+            if cache.shape[2] == 0:
+                x = paddle.nn.functional.pad(x, [self.lorder, 0], mode="constant", value=0.0, data_format="NCL")
+            else:
+                assert cache.shape[0] == x.shape[0]
+                assert cache.shape[1] == x.shape[1]
+                x = paddle.concat([cache, x], axis=2)
+            assert x.shape[2] > self.lorder
+            new_cache = x[:, :, -self.lorder :]
+        else:
+            new_cache = paddle.zeros([0, 0, 0], dtype=x.dtype)
+        x = self.pointwise_conv1(x)
+        x = paddle.nn.functional.glu(x=x, axis=1)
+        x = self.depthwise_conv(x)
+        if self.use_layer_norm:
+            x = x.transpose([0, 2, 1])
+        x = self.activation(self.norm(x))
+        if self.use_layer_norm:
+            x = x.transpose([0, 2, 1])
+        x = self.pointwise_conv2(x)
+        if mask_pad.shape[2] > 0:
+            x = x * mask_pad.astype(x.dtype)
+        return x.transpose([0, 2, 1]), new_cache
diff --git a/paddlespeech/t2s/modules/flow/decoder.py b/paddlespeech/t2s/modules/flow/decoder.py
new file mode 100644
index 000000000..0489317e7
--- /dev/null
+++ b/paddlespeech/t2s/modules/flow/decoder.py
@@ -0,0 +1,762 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
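
An illustrative aside, not part of the patch: a minimal, self-contained sketch of the left-context caching that the causal branch of the ConvolutionModule above relies on. The 4-channel depthwise Conv1D, chunk length 16 and the three-chunk loop are arbitrary toy values; only the pad/concat/slice pattern mirrors the module's forward pass.

import paddle

kernel_size = 5
lorder = kernel_size - 1                              # left context kept between chunks
conv = paddle.nn.Conv1D(4, 4, kernel_size, groups=4)  # depthwise conv, no padding

cache = paddle.zeros([1, 4, lorder])                  # empty history before the first chunk
outputs = []
for _ in range(3):                                    # three streaming chunks
    chunk = paddle.randn([1, 4, 16])                  # (batch, channels, time)
    x = paddle.concat([cache, chunk], axis=2)         # prepend cached left context
    cache = x[:, :, -lorder:]                         # tail becomes the next chunk's cache
    outputs.append(conv(x))                           # output length equals the chunk length
streamed = paddle.concat(outputs, axis=2)             # same as the conv over the zero-left-padded full sequence
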
+from typing import Tuple, Any, Dict, Optional +import paddle +import math +from paddle import nn +import paddle.nn.functional as F +from einops import pack, rearrange, repeat +from paddlespeech.t2s.models.CosyVoice.common import mask_to_bias +from paddlespeech.t2s.models.CosyVoice.mask import add_optional_chunk_mask +from .matcha_transformer import BasicTransformerBlock + +def get_activation(act_fn): + if act_fn == "silu": + return nn.Silu() + elif act_fn == "mish": + return nn.Mish() + elif act_fn == "relu": + return nn.ReLU() + elif act_fn == "gelu": + return nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + +class Block1D(nn.Layer): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = nn.Sequential( + nn.Conv1D(dim, dim_out, 3, padding=1), + nn.GroupNorm(groups, dim_out), + nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + +class ResnetBlock1D(nn.Layer): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.Mish(), + nn.Linear(time_emb_dim, dim_out) + ) + + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv1D(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + # 添加时间嵌入并调整维度 + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + +class Downsample1D(nn.Layer): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1D(dim, dim, 3, stride=2, padding=1) + + def forward(self, x): + return self.conv(x) + +class TimestepEmbedding(nn.Layer): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + + self.linear_1 = nn.Linear(in_channels, time_embed_dim) + + if cond_proj_dim is not None: + self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) + else: + self.cond_proj = None + + self.act = get_activation(act_fn) + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) + + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None and self.cond_proj is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + +class Upsample1D(nn.Layer): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
+ """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.Conv1DTranspose(channels, self.out_channels, 4, stride=2, padding=1) + elif use_conv: + self.conv = nn.Conv1D(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") + + if self.use_conv: + outputs = self.conv(outputs) + + return outputs + +class Transpose(nn.Layer): + def __init__(self, dim0: int, dim1: int): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = paddle.transpose(x, [0, self.dim1, self.dim0]) + return x + +class CausalConv1d(nn.Conv1D): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + padding_mode: str = 'zeros' + ) -> None: + super(CausalConv1d, self).__init__(in_channels, out_channels, + kernel_size, stride, + padding=0, dilation=dilation, + groups=groups, + padding_mode=padding_mode) + assert stride == 1 + self.causal_padding = kernel_size - 1 + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = F.pad(x, (self.causal_padding, 0), value=0.0) + x = super(CausalConv1d, self).forward(x) + return x + +class SinusoidalPosEmb(paddle.nn.Layer): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = paddle.exp(paddle.arange(half_dim).astype('float32') * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1) + return emb + +class CausalBlock1D(Block1D): + def __init__(self, dim: int, dim_out: int): + super(CausalBlock1D, self).__init__(dim, dim_out) + self.block = nn.Sequential( + CausalConv1d(dim, dim_out, 3), + Transpose(1, 2), + nn.LayerNorm(dim_out), + Transpose(1, 2), + nn.Mish() + ) + + def forward(self, x: paddle.Tensor, mask: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + output = self.block(x * mask) + return output * mask + + +class CausalResnetBlock1D(ResnetBlock1D): + def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): + super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) + self.block1 = CausalBlock1D(dim, dim_out) + self.block2 = CausalBlock1D(dim_out, dim_out) + +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, +) -> paddle.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + + Returns: + paddle.Tensor: mask + + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + pos_idx = paddle.arange(size, dtype='int64') + block_value = (paddle.floor_divide(pos_idx, chunk_size) + 1) * 
chunk_size + ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) + + return ret + + +def add_optional_chunk_mask(xs: paddle.Tensor, + masks: paddle.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True): + """ Apply optional mask for encoder. + + Args: + xs (paddle.Tensor): padded input, (B, L, D), L for max length + mask (paddle.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + paddle.Tensor: chunk mask of the input xs. + """ + # Whether to use chunk mask or not + if use_dynamic_chunk: + max_len = xs.shape[1] + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + # chunk size is either [1, 25] or full context(max_len). + # Since we use 4 times subsampling and allow up to 1s(100 frames) + # delay, the maximum frame is 100 / 4 = 25. 
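+            # Draw chunk_size ~ U[1, max_len); a draw above max_len // 2 is
+            # promoted to full-context attention (when enable_full_context),
+            # otherwise it is folded into the streaming range [1, 25] below.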
+ chunk_size = paddle.randint(1, max_len, shape=(1,)).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = paddle.randint(0, max_left_chunks, shape=(1,)).item() + + chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size, + num_left_chunks) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size, + num_left_chunks) # (L, L) + chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) + chunk_masks = masks & chunk_masks # (B, L, L) + else: + chunk_masks = masks + assert chunk_masks.dtype == paddle.bool + if (chunk_masks.sum(axis=-1) == 0).sum().item() != 0: + print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!') + all_false_mask = chunk_masks.sum(axis=-1) == 0 + chunk_masks = paddle.where(all_false_mask.unsqueeze(-1), paddle.ones_like(chunk_masks, dtype='bool'), chunk_masks) + + return chunk_masks + +def mask_to_bias(mask: paddle.Tensor, dtype: str) -> paddle.Tensor: + assert mask.dtype == paddle.bool, "Input mask must be of boolean type" + assert dtype in [paddle.float32, paddle.bfloat16, paddle.float16], f"Unsupported dtype: {dtype}" + mask = mask.astype(dtype) + mask = (1.0 - mask) * -1.0e+10 + + return mask + +class ConditionalDecoder(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.down_blocks = nn.LayerList([]) + self.mid_blocks = nn.LayerList([]) + self.up_blocks = nn.LayerList([]) + + output_channel = in_channels + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else nn.Conv1D(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append(nn.LayerList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.LayerList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else nn.Conv1D(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append(nn.LayerList([resnet, transformer_blocks, upsample])) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = nn.Conv1D(channels[-1], self.out_channels, 1) + self.initialize_weights() + + def initialize_weights(self): + for m in self.sublayers(): + if isinstance(m, nn.Conv1D): + nn.initializer.KaimingNormal(m.weight, nonlinearity='relu') + if m.bias is not None: + nn.initializer.Constant(m.bias, value=0) + elif isinstance(m, nn.GroupNorm): + nn.initializer.Constant(m.weight, value=1) + nn.initializer.Constant(m.bias, value=0) + elif isinstance(m, nn.Linear): + nn.initializer.KaimingNormal(m.weight, nonlinearity='relu') + if m.bias is not None: + nn.initializer.Constant(m.bias, value=0) + + def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False): + """Forward pass of the UNet1DConditional model. + + Args: + x (paddle.Tensor): shape (batch_size, in_channels, time) + mask (paddle.Tensor): shape (batch_size, 1, time) + t (paddle.Tensor): shape (batch_size) + spks (paddle.Tensor, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (paddle.Tensor, optional): placeholder for future use. 
Defaults to None. + + Returns: + paddle.Tensor: output tensor + """ + + t = self.time_embeddings(t).astype(t.dtype) + t = self.time_mlp(t) + + x = pack([x, mu], "b * t")[0] + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask + + +class CausalConditionalDecoder(nn.Layer): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + static_chunk_size=50, + num_decoding_left_chunks=2, + ): + """ + This decoder requires an input with the same shape of the target. So, if your text content + is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. 
+ """ + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, + time_embed_dim=time_embed_dim, + act_fn="silu", + ) + self.static_chunk_size = static_chunk_size + self.num_decoding_left_chunks = num_decoding_left_chunks + self.down_blocks = nn.LayerList([]) + self.mid_blocks = nn.LayerList([]) + self.up_blocks = nn.LayerList([]) + + output_channel = in_channels + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3) # 假设已实现 + ) + self.down_blocks.append(nn.LayerList([resnet, transformer_blocks, downsample])) + + for _ in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) + + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + + self.mid_blocks.append(nn.LayerList([resnet, transformer_blocks])) + + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] * 2 + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = CausalResnetBlock1D( + dim=input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = nn.LayerList( + [ + BasicTransformerBlock( + dim=output_channel, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) # 假设已实现 + if not is_last + else CausalConv1d(output_channel, output_channel, 3) + ) + self.up_blocks.append(nn.LayerList([resnet, transformer_blocks, upsample])) + self.final_block = CausalBlock1D(channels[-1], channels[-1]) # 假设已实现 + self.final_proj = nn.Conv1D(channels[-1], self.out_channels, 1) # 使用 Conv1D + self.initialize_weights() + + def initialize_weights(self): + for m in self.sublayers(): # 使用 sublayers() 而不是 modules() + if isinstance(m, nn.Conv1D): + nn.initializer.KaimingNormal(m.weight, nonlinearity='relu') + if m.bias is not None: + initializer = nn.initializer.Constant(value=0) + initializer(m.bias) + elif isinstance(m, nn.GroupNorm): + nn.initializer.Constant(m.weight, value=1) + nn.initializer.Constant(m.bias, value=0) + elif isinstance(m, nn.Linear): + nn.initializer.KaimingNormal(m.weight, nonlinearity='relu') + if m.bias is not None: + initializer = nn.initializer.Constant(value=0) + initializer(m.bias) + def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False): + """Forward pass of the UNet1DConditional model. 
+ + Args: + x (paddle.Tensor): shape (batch_size, in_channels, time) + mask (paddle.Tensor): shape (batch_size, 1, time) + mu (paddle.Tensor): mean tensor for conditioning + t (paddle.Tensor): shape (batch_size) + spks (paddle.Tensor, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (paddle.Tensor, optional): placeholder for future use. Defaults to None. + streaming (bool, optional): whether to use streaming mode. Defaults to False. + + Returns: + paddle.Tensor: output tensor + """ + t = self.time_embeddings(t).astype(t.dtype) # 使用 astype 代替 .to(t.dtype) + t = self.time_mlp(t) + x = pack([x, mu], "b * t")[0] # 假设 pack 函数已实现 + if spks is not None: + spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) # 假设 repeat 函数已实现 + x = pack([x, spks], "b * t")[0] + if cond is not None: + x = pack([x, cond], "b * t")[0] + + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + + x = rearrange(x, "b c t -> b t c").contiguous() # 假设 rearrange 函数已实现 + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, self.static_chunk_size, -1) # 使用 astype('bool') + else: + attn_mask = add_optional_chunk_mask(x, mask_down.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) # 使用 .shape 而不是 .size() + attn_mask = mask_to_bias(attn_mask, x.dtype) # 假设 mask_to_bias 函数已实现 + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + + x = rearrange(x, "b t c -> b c t").contiguous() + hiddens.append(x) # Save hidden states for skip connections + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, self.static_chunk_size, -1) + else: + attn_mask = add_optional_chunk_mask(x, mask_mid.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + skip = hiddens.pop() + x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] + x = resnet(x, mask_up, t) + x = rearrange(x, "b c t -> b t c").contiguous() + if streaming is True: + attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, self.static_chunk_size, -1) + else: + attn_mask = add_optional_chunk_mask(x, mask_up.astype('bool'), False, False, 0, 0, -1).repeat(1, x.shape[1], 1) + attn_mask = mask_to_bias(attn_mask, x.dtype) + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, + attention_mask=attn_mask, + timestep=t, + ) + x = rearrange(x, "b t c -> b c t").contiguous() + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/paddlespeech/t2s/modules/flow/diffusers_activatioins.py b/paddlespeech/t2s/modules/flow/diffusers_activatioins.py new file mode 100644 index 000000000..031930167 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/diffusers_activatioins.py @@ 
-0,0 +1,132 @@ +import paddle + +from ..utils import deprecate +from ..utils.import_utils import is_torch_npu_available + +if is_torch_npu_available(): + import torch_npu +ACTIVATION_FUNCTIONS = { + "swish": paddle.nn.SiLU(), + "silu": paddle.nn.SiLU(), + "mish": paddle.nn.Mish(), + "gelu": paddle.nn.GELU(), + "relu": paddle.nn.ReLU(), +} + + +def get_activation(act_fn: str) -> paddle.nn.Layer: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. + """ + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + + +class FP32SiLU(paddle.nn.Layer): + """ + SiLU activation function with input upcasted to torch.float32. + """ + + def __init__(self): + super().__init__() + + def forward(self, inputs: paddle.Tensor) -> paddle.Tensor: + return paddle.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class GELU(paddle.nn.Layer): + """ + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True + ): + super().__init__() + self.proj = paddle.nn.Linear( + in_features=dim_in, out_features=dim_out, bias_attr=bias + ) + self.approximate = approximate + + def gelu(self, gate: paddle.Tensor) -> paddle.Tensor: + if gate.device.type != "mps": + return paddle.nn.functional.gelu(gate, approximate=self.approximate) + return paddle.nn.functional.gelu( + gate.to(dtype=paddle.float32), approximate=self.approximate + ).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(paddle.nn.Layer): + """ + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = paddle.nn.Linear( + in_features=dim_in, out_features=dim_out * 2, bias_attr=bias + ) + + def gelu(self, gate: paddle.Tensor) -> paddle.Tensor: + if gate.device.type != "mps": + return paddle.nn.functional.gelu(gate) + return paddle.nn.functional.gelu(gate.to(dtype=paddle.float32)).to( + dtype=gate.dtype + ) + + def forward(self, hidden_states, *args, **kwargs): + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
+ deprecate("scale", "1.0.0", deprecation_message) + hidden_states = self.proj(hidden_states) + if is_torch_npu_available(): + return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0] + else: + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(paddle.nn.Layer): + """ + The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this + [paper](https://arxiv.org/abs/1606.08415). + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = paddle.nn.Linear( + in_features=dim_in, out_features=dim_out, bias_attr=bias + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.proj(x) + return x * paddle.nn.functional.sigmoid(1.702 * x) diff --git a/paddlespeech/t2s/modules/flow/encoder_layer.py b/paddlespeech/t2s/modules/flow/encoder_layer.py new file mode 100644 index 000000000..d7debadf9 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/encoder_layer.py @@ -0,0 +1,205 @@ +import paddle + +"""Encoder self-attention layer definition.""" +from typing import Optional, Tuple + + +class TransformerEncoderLayer(paddle.nn.Layer): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: to use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: paddle.nn.Layer, + feed_forward: paddle.nn.Layer, + dropout_rate: float, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.norm2 = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: paddle.Tensor, + mask: paddle.Tensor, + pos_emb: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + att_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + cnn_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): just for interface compatibility + to ConformerEncoderLayer + mask_pad (torch.Tensor): does not used in transformer layer, + just for unified api with conformer. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2), not used here, it's for interface + compatibility to ConformerEncoderLayer. 
+ Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). + + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + x_att, new_att_cache = self.self_attn( + x, x, x, mask, pos_emb=pos_emb, cache=att_cache + ) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm1(x) + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + fake_cnn_cache = paddle.zeros((0, 0, 0), dtype=x.dtype, device=x.place) + return x, mask, new_att_cache, fake_cnn_cache + + +class ConformerEncoderLayer(paddle.nn.Layer): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. + """ + + def __init__( + self, + size: int, + self_attn: paddle.nn.Layer, + feed_forward: Optional[paddle.nn.Layer] = None, + feed_forward_macaron: Optional[paddle.nn.Layer] = None, + conv_module: Optional[paddle.nn.Layer] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.norm_mha = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + if feed_forward_macaron is not None: + self.norm_ff_macaron = paddle.nn.LayerNorm( + normalized_shape=size, epsilon=1e-12 + ) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.norm_final = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: paddle.Tensor, + mask: paddle.Tensor, + pos_emb: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + att_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + cnn_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. 
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). + """ + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + new_cnn_cache = paddle.zeros((0, 0, 0), dtype=x.dtype, device=x.place) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.norm_conv(x) + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + if self.conv_module is not None: + x = self.norm_final(x) + return x, mask, new_att_cache, new_cnn_cache diff --git a/paddlespeech/t2s/modules/flow/flow.py b/paddlespeech/t2s/modules/flow/flow.py new file mode 100644 index 000000000..2c1517f78 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/flow.py @@ -0,0 +1,326 @@ +import logging +import random +from typing import Dict, Optional + +import paddle +from omegaconf import DictConfig + +from paddlespeech.t2s.models.CosyVoice.mask import make_pad_mask + + +class MaskedDiffWithXvec(paddle.nn.Layer): + def __init__( + self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 4096, + input_frame_rate: int = 50, + only_mask_loss: bool = True, + encoder: paddle.nn.Layer = None, + length_regulator: paddle.nn.Layer = None, + decoder: paddle.nn.Layer = None, + decoder_conf: Dict = { + "in_channels": 240, + "out_channel": 80, + "spk_emb_dim": 80, + "n_spks": 1, + "cfm_params": DictConfig( + { + "sigma_min": 1e-06, + "solver": "euler", + "t_scheduler": "cosine", + "training_cfg_rate": 0.2, + "inference_cfg_rate": 0.7, + "reg_loss_type": "l1", + } + ), + "decoder_params": { + "channels": [256, 256], + "dropout": 0.0, + "attention_head_dim": 64, + "n_blocks": 4, + "num_mid_blocks": 12, + "num_heads": 8, + "act_fn": "gelu", + }, + }, + mel_feat_conf: Dict = { + "n_fft": 1024, + "num_mels": 80, + "sampling_rate": 22050, + "hop_size": 256, + "win_size": 1024, + "fmin": 0, + "fmax": 8000, + }, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = paddle.nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = paddle.nn.Linear( + in_features=spk_embed_dim, out_features=output_size + ) + self.encoder = encoder + self.encoder_proj = 
paddle.nn.Linear( + in_features=self.encoder.output_size(), out_features=output_size + ) + self.decoder = decoder + self.length_regulator = length_regulator + self.only_mask_loss = only_mask_loss + + def forward( + self, batch: dict, device: paddle.device + ) -> Dict[str, Optional[paddle.Tensor]]: + token = batch["speech_token"].to(device) + token_len = batch["speech_token_len"].to(device) + feat = batch["speech_feat"].to(device) + feat_len = batch["speech_feat_len"].to(device) + embedding = batch["embedding"].to(device) + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(paddle.clip(token, min=0)) * mask + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + h, h_lengths = self.length_regulator(h, feat_len) + conds = paddle.zeros(feat.shape, device=token.place) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = paddle.transpose(conds, perm=[0, 2, 1]) + mask = (~make_pad_mask(feat_len)).to(h) + loss, _ = self.decoder.compute_loss( + paddle.transpose(feat, perm=[0, 2, 1]), + mask.unsqueeze(1), + paddle.transpose(h, perm=[0, 2, 1]), + embedding, + cond=conds, + ) + return {"loss": loss} + + @paddle.no_grad() + def inference( + self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding, + flow_cache, + ): + assert token.shape[0] == 1 + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + token_len1, token_len2 = prompt_token.shape[1], token.shape[1] + token, token_len = ( + paddle.cat([prompt_token, token], dim=1), + prompt_token_len + token_len, + ) + mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) + token = self.input_embedding(paddle.clip(token, min=0)) * mask + h, h_lengths = self.encoder(token, token_len) + h = self.encoder_proj(h) + mel_len1, mel_len2 = prompt_feat.shape[1], int( + token_len2 / self.input_frame_rate * 22050 / 256 + ) + h, h_lengths = self.length_regulator.inference( + h[:, :token_len1], + h[:, token_len1:], + mel_len1, + mel_len2, + self.input_frame_rate, + ) + conds = paddle.zeros( + [1, mel_len1 + mel_len2, self.output_size], device=token.place + ).to(h.dtype) + conds[:, :mel_len1] = prompt_feat + conds = paddle.transpose(conds, perm=[0, 2, 1]) + + mask = (~make_pad_mask(paddle.tensor([mel_len1 + mel_len2]))).to(h) + feat, flow_cache = self.decoder( + mu=paddle.transpose(h, perm=[0, 2, 1]), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10, + prompt_len=mel_len1, + cache=flow_cache, + ) + feat = feat[:, :, mel_len1:] + assert feat.shape[2] == mel_len2 + return feat.float(), flow_cache + + +class CausalMaskedDiffWithXvec(paddle.nn.Layer): + def __init__( + self, + input_size: int = 512, + output_size: int = 80, + spk_embed_dim: int = 192, + output_type: str = "mel", + vocab_size: int = 6561, + input_frame_rate: int = 25, + only_mask_loss: bool = True, + token_mel_ratio: int = 2, + pre_lookahead_len: int = 3, + encoder: paddle.nn.Layer = None, + decoder: paddle.nn.Layer = None, + decoder_conf: Dict = { + "in_channels": 240, + "out_channel": 80, + "spk_emb_dim": 80, + "n_spks": 1, + "cfm_params": DictConfig( + { + "sigma_min": 1e-06, + "solver": "euler", + "t_scheduler": "cosine", + "training_cfg_rate": 0.2, + 
"inference_cfg_rate": 0.7, + "reg_loss_type": "l1", + } + ), + "decoder_params": { + "channels": [256, 256], + "dropout": 0.0, + "attention_head_dim": 64, + "n_blocks": 4, + "num_mid_blocks": 12, + "num_heads": 8, + "act_fn": "gelu", + }, + }, + mel_feat_conf: Dict = { + "n_fft": 1024, + "num_mels": 80, + "sampling_rate": 22050, + "hop_size": 256, + "win_size": 1024, + "fmin": 0, + "fmax": 8000, + }, + ): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.decoder_conf = decoder_conf + self.mel_feat_conf = mel_feat_conf + self.vocab_size = vocab_size + self.output_type = output_type + self.input_frame_rate = input_frame_rate + logging.info(f"input frame rate={self.input_frame_rate}") + self.input_embedding = paddle.nn.Embedding(vocab_size, input_size) + self.spk_embed_affine_layer = paddle.nn.Linear( + in_features=spk_embed_dim, out_features=output_size + ) + self.encoder = encoder + self.encoder_proj = paddle.nn.Linear( + in_features=self.encoder.output_size(), out_features=output_size + ) + self.decoder = decoder + self.only_mask_loss = only_mask_loss + self.token_mel_ratio = token_mel_ratio + self.pre_lookahead_len = pre_lookahead_len + + def forward( + self, batch: dict, device: paddle.device + ) -> Dict[str, Optional[paddle.Tensor]]: + token = batch["speech_token"].to(device) + token_len = batch["speech_token_len"].to(device) + feat = batch["speech_feat"].to(device) + feat_len = batch["speech_feat_len"].to(device) + embedding = batch["embedding"].to(device) + streaming = True if random.random() < 0.5 else False + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) + token = self.input_embedding(paddle.clip(token, min=0)) * mask + h, h_lengths = self.encoder(token, token_len, streaming=streaming) + h = self.encoder_proj(h) + conds = paddle.zeros(feat.shape, device=token.place) + for i, j in enumerate(feat_len): + if random.random() < 0.5: + continue + index = random.randint(0, int(0.3 * j)) + conds[i, :index] = feat[i, :index] + conds = paddle.transpose(conds, perm=[0, 2, 1]) + + mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h) + loss, _ = self.decoder.compute_loss( + paddle.transpose(feat, perm=[0, 2, 1]).contiguous(), + mask.unsqueeze(1), + paddle.transpose(h, perm=[0, 2, 1]).contiguous(), + embedding, + cond=conds, + streaming=streaming, + ) + return {"loss": loss} + + @paddle.no_grad() + def inference( + self, + token, + token_len, + prompt_token, + prompt_token_len, + prompt_feat, + prompt_feat_len, + embedding, + streaming, + finalize, + ): + assert token.shape[0] == 1 + embedding = paddle.nn.functional.normalize(x=embedding, axis=1) + embedding = self.spk_embed_affine_layer(embedding) + + token, token_len = ( + paddle.cat([prompt_token, token], dim=1), + prompt_token_len + token_len, + ) + mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) + token = self.input_embedding(paddle.clip(token, min=0)) * mask + if finalize is True: + h, h_lengths = self.encoder(token, token_len, streaming=streaming) + else: + token, context = ( + token[:, : -self.pre_lookahead_len], + token[:, -self.pre_lookahead_len :], + ) + h, h_lengths = self.encoder( + token, token_len, context=context, streaming=streaming + ) + + + mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] + h = self.encoder_proj(h) + conds = paddle.zeros( + [1, mel_len1 + mel_len2, self.output_size] + 
).to(h.dtype) + conds[:, :mel_len1] = prompt_feat + conds = paddle.transpose(conds, perm=[0, 2, 1]) + mask = (~make_pad_mask(paddle.to_tensor([mel_len1 + mel_len2],dtype='int32'))).to(h) + feat, _ = self.decoder( + mu=paddle.transpose(h, perm=[0, 2, 1]).contiguous(), + mask=mask.unsqueeze(1), + spks=embedding, + cond=conds, + n_timesteps=10, + streaming=streaming, + ) + + feat = feat[:, :, mel_len1:] + assert feat.shape[2] == mel_len2 + return feat.float(), None diff --git a/paddlespeech/t2s/modules/flow/flow_matching.py b/paddlespeech/t2s/modules/flow/flow_matching.py new file mode 100644 index 000000000..98443535c --- /dev/null +++ b/paddlespeech/t2s/modules/flow/flow_matching.py @@ -0,0 +1,342 @@ +import paddle +from abc import ABC +from paddlespeech.t2s.models.CosyVoice.common import set_all_random_seed + +class BASECFM(paddle.nn.Layer, ABC): + def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 0.0001 + self.estimator = None + + @paddle.no_grad() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = paddle.randn(shape=mu.shape, dtype=mu.dtype) * temperature + t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1) + return self.solve_euler( + z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond + ) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + sol = [] + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype) + z = paddle.randn(shape=x1.shape, dtype=x1.dtype) + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + loss = paddle.nn.functional.mse_loss( + input=self.estimator(y, mask, mu, t.squeeze(), spks), + label=u, + reduction="sum", + ) / (paddle.sum(mask) * u.shape[1]) + return loss, y + +class ConditionalCFM(BASECFM): + def __init__( + self, + in_channels, + cfm_params, + n_spks=1, + spk_emb_dim=64, + estimator: paddle.nn.Layer = None, + ): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + self.t_scheduler = cfm_params.t_scheduler + self.training_cfg_rate = cfm_params.training_cfg_rate + self.inference_cfg_rate = cfm_params.inference_cfg_rate + in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) + self.estimator = estimator + + @paddle.no_grad() + def forward( + self, + mu, + mask, + n_timesteps, + temperature=1.0, + spks=None, + cond=None, + prompt_len=0, + cache=paddle.zeros([1, 80, 0, 2]), + ): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = ( + paddle.randn(shape=mu.shape, dtype=mu.dtype).to(mu.place).to(mu.dtype) + * temperature + ) + cache_size = cache.shape[2] + if cache_size != 0: + z[:, :, :cache_size] = cache[:, :, :, 0] + mu[:, :, :cache_size] = cache[:, :, :, 1] + z_cache = paddle.cat([z[:, :, :prompt_len], z[:, :, -34:]], axis=2) + mu_cache = paddle.cat([mu[:, :, :prompt_len], mu[:, :, -34:]], axis=2) + cache = paddle.stack([z_cache, mu_cache], axis=-1) + t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1, dtype=mu.dtype) + if self.t_scheduler == "cosine": + t_span = 1 - paddle.cos(t_span * 0.5 * paddle.pi) + return ( + self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), + cache, + ) + + def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t = t.unsqueeze(axis=0) + sol = [] + x_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype) + mask_in = paddle.zeros([2, 1, x.shape[2]], dtype=x.dtype) + mu_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype) + t_in = paddle.zeros([2], dtype=x.dtype) + spks_in = paddle.zeros([2, 80], dtype=x.dtype) + cond_in = paddle.zeros([2, 80, x.shape[2]], dtype=x.dtype) + for step in range(1, len(t_span)): + x_in[:] = x + mask_in[:] = mask + mu_in[0] = mu + t_in[:] = t.unsqueeze(0) + spks_in[0] = spks + cond_in[0] = cond + dphi_dt = self.forward_estimator( + x_in, mask_in, mu_in, t_in, spks_in, cond_in, streaming + ) + dphi_dt, cfg_dphi_dt = paddle.split(dphi_dt, [x.shape[0], x.shape[0]], axis=0) + dphi_dt = ( + 1.0 + self.inference_cfg_rate + ) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + return sol[-1].float() + + def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False): + if isinstance(self.estimator, paddle.nn.Layer): + return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming) + else: + [estimator, stream], trt_engine = self.estimator.acquire_estimator() + paddle.device.current_stream().synchronize() + with stream: + estimator.set_input_shape("x", (2, 80, x.shape[2])) + estimator.set_input_shape("mask", (2, 1, x.shape[2])) + estimator.set_input_shape("mu", (2, 80, x.shape[2])) + estimator.set_input_shape("t", (2,)) + estimator.set_input_shape("spks", (2, 80)) + estimator.set_input_shape("cond", (2, 80, x.shape[2])) + data_ptrs = [ + x.contiguous().data_ptr(), + mask.contiguous().data_ptr(), + mu.contiguous().data_ptr(), + t.contiguous().data_ptr(), + spks.contiguous().data_ptr(), + cond.contiguous().data_ptr(), + x.data_ptr(), + ] + for i, j in enumerate(data_ptrs): + estimator.set_tensor_address(trt_engine.get_tensor_name(i), j) + assert ( + estimator.execute_async_v3( + paddle.device.current_stream().cuda_stream + ) + is True + ) + paddle.device.current_stream().synchronize() + self.estimator.release_estimator(estimator, stream) + return x + + def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
+ shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype) + if self.t_scheduler == "cosine": + t = 1 - paddle.cos(t * 0.5 * paddle.pi) + z = paddle.randn(shape=x1.shape, dtype=x1.dtype) + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + if self.training_cfg_rate > 0: + cfg_mask = paddle.rand(shape=b) > self.training_cfg_rate + mu = mu * cfg_mask.view(-1, 1, 1) + spks = spks * cfg_mask.view(-1, 1) + cond = cond * cfg_mask.view(-1, 1, 1) + pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming) + loss = paddle.nn.functional.mse_loss( + input=pred * mask, label=u * mask, reduction="sum" + ) / (paddle.sum(mask) * u.shape[1]) + return loss, y + + +class CausalConditionalCFM(ConditionalCFM): + def __init__( + self, + in_channels, + cfm_params, + n_spks=1, + spk_emb_dim=64, + estimator: paddle.nn.Layer = None, + ): + super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) + set_all_random_seed(42) + self.rand_noise = paddle.randn([1, 80, 50 * 300]) + @paddle.no_grad() + def forward( + self, + mu, + mask, + n_timesteps, + temperature=1.0, + spks=None, + cond=None, + streaming=False, + ): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + + z = self.rand_noise[:, :, : mu.shape[2]].to(mu.place).to(mu.dtype) * temperature + t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1, dtype=mu.dtype) + if self.t_scheduler == "cosine": + t_span = 1 - paddle.cos(t_span * 0.5 * paddle.pi) + + return ( + self.solve_euler( + z, + t_span=t_span, + mu=mu, + mask=mask, + spks=spks, + cond=cond, + streaming=streaming, + ), + None, + ) diff --git a/paddlespeech/t2s/modules/flow/flow_matching_matcha.py b/paddlespeech/t2s/modules/flow/flow_matching_matcha.py new file mode 100644 index 000000000..0bcea69e4 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/flow_matching_matcha.py @@ -0,0 +1,124 @@ +from abc import ABC + +import paddle +from matcha.models.components.decoder import Decoder +from matcha.utils.pylogger import get_pylogger + +log = get_pylogger(__name__) + + +class BASECFM(paddle.nn.Layer, ABC): + def __init__(self, n_feats, cfm_params, n_spks=1, spk_emb_dim=128): + super().__init__() + self.n_feats = n_feats + self.n_spks = n_spks + self.spk_emb_dim = spk_emb_dim + self.solver = cfm_params.solver + if hasattr(cfm_params, "sigma_min"): + self.sigma_min = cfm_params.sigma_min + else: + self.sigma_min = 0.0001 + self.estimator = None + + @paddle.no_grad() + def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. 
Defaults to 1.0. + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + z = paddle.randn(shape=mu.shape, dtype=mu.dtype) * temperature + t_span = paddle.linspace(start=0, stop=1, num=n_timesteps + 1) + return self.solve_euler( + z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond + ) + + def solve_euler(self, x, t_span, mu, mask, spks, cond): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): output_mask + shape: (batch_size, 1, mel_timesteps) + spks (torch.Tensor, optional): speaker ids. Defaults to None. + shape: (batch_size, spk_emb_dim) + cond: Not used but kept for future purposes + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + sol = [] + for step in range(1, len(t_span)): + dphi_dt = self.estimator(x, mask, mu, t, spks, cond) + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + return sol[-1] + + def compute_loss(self, x1, mask, mu, spks=None, cond=None): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = mu.shape + t = paddle.rand(shape=[b, 1, 1], dtype=mu.dtype) + z = paddle.randn(shape=x1.shape, dtype=x1.dtype) + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + loss = paddle.nn.functional.mse_loss( + input=self.estimator(y, mask, mu, t.squeeze(), spks), + label=u, + reduction="sum", + ) / (paddle.sum(mask) * u.shape[1]) + return loss, y + + +class CFM(BASECFM): + def __init__( + self, + in_channels, + out_channel, + cfm_params, + decoder_params, + n_spks=1, + spk_emb_dim=64, + ): + super().__init__( + n_feats=in_channels, + cfm_params=cfm_params, + n_spks=n_spks, + spk_emb_dim=spk_emb_dim, + ) + in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) + self.estimator = Decoder( + in_channels=in_channels, out_channels=out_channel, **decoder_params + ) diff --git a/paddlespeech/t2s/modules/flow/length_regulator.py b/paddlespeech/t2s/modules/flow/length_regulator.py new file mode 100644 index 000000000..db6a35818 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/length_regulator.py @@ -0,0 +1,91 @@ +from typing import Tuple + +import paddle + +from cosyvoice.utils.mask import make_pad_mask + +############################## 相关utils函数,如下 ############################## + +def _Tensor_max(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.maximum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.max(self, *args, **kwargs), paddle.argmax(self, *args, **kwargs) + 
else: + ret = paddle.max(self, *args, **kwargs) + + return ret + +setattr(paddle.Tensor, "_max", _Tensor_max) +############################## 相关utils函数,如上 ############################## + + + +class InterpolateRegulator(paddle.nn.Layer): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + out_channels: int = None, + groups: int = 1, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = paddle.nn.LayerList(sublayers=[]) + if len(sampling_ratios) > 0: + for _ in sampling_ratios: + module = paddle.nn.Conv1d(channels, channels, 3, 1, 1) + norm = paddle.nn.GroupNorm(num_groups=groups, num_channels=channels) + act = paddle.nn.Mish() + model.extend([module, norm, act]) + model.append(paddle.nn.Conv1d(channels, out_channels, 1, 1)) + self.model = paddle.nn.Sequential(*model) + + def forward(self, x, ylens=None): + mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) + x = paddle.nn.functional.interpolate( + x=x.transpose(1, 2).contiguous(), size=ylens._max(), mode="linear" + ) + out = self.model(x).transpose(1, 2).contiguous() + olens = ylens + return out * mask, olens + + def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): + if x2.shape[1] > 40: + x2_head = paddle.nn.functional.interpolate( + x=x2[:, :20].transpose(1, 2).contiguous(), + size=int(20 / input_frame_rate * 22050 / 256), + mode="linear", + ) + x2_mid = paddle.nn.functional.interpolate( + x=x2[:, 20:-20].transpose(1, 2).contiguous(), + size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, + mode="linear", + ) + x2_tail = paddle.nn.functional.interpolate( + x=x2[:, -20:].transpose(1, 2).contiguous(), + size=int(20 / input_frame_rate * 22050 / 256), + mode="linear", + ) + x2 = paddle.cat([x2_head, x2_mid, x2_tail], dim=2) + else: + x2 = paddle.nn.functional.interpolate( + x=x2.transpose(1, 2).contiguous(), size=mel_len2, mode="linear" + ) + if x1.shape[1] != 0: + x1 = paddle.nn.functional.interpolate( + x=x1.transpose(1, 2).contiguous(), size=mel_len1, mode="linear" + ) + x = paddle.cat([x1, x2], dim=2) + else: + x = x2 + out = self.model(x).transpose(1, 2).contiguous() + return out, mel_len1 + mel_len2 \ No newline at end of file diff --git a/paddlespeech/t2s/modules/flow/lora.py b/paddlespeech/t2s/modules/flow/lora.py new file mode 100644 index 000000000..557f9ff15 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/lora.py @@ -0,0 +1,123 @@ +from typing import Optional, Tuple, Union + +import paddle + + +class LoRALinearLayer(paddle.nn.Layer): + """ + A linear layer that is used with LoRA. + + Parameters: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + rank (`int`, `optional`, defaults to 4): + The rank of the LoRA layer. + network_alpha (`float`, `optional`, defaults to `None`): + The value of the network alpha used for stable learning and preventing underflow. This value has the same + meaning as the `--network_alpha` option in the kohya-ss trainer script. See + https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + device (`torch.device`, `optional`, defaults to `None`): + The device to use for the layer's weights. + dtype (`torch.dtype`, `optional`, defaults to `None`): + The dtype to use for the layer's weights. 
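+
+    The layer computes ``up(down(x))`` and, when ``network_alpha`` is set,
+    scales the result by ``network_alpha / rank`` before casting back to the
+    input dtype.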
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 4, + network_alpha: Optional[float] = None, + device: Optional[Union[paddle.CPUPlace, paddle.CUDAPlace, str]] = None, + dtype: Optional[paddle.dtype] = None, + ): + super().__init__() + self.down = paddle.nn.Linear( + in_features=in_features, out_features=rank, bias_attr=False + ) + self.up = paddle.nn.Linear( + in_features=rank, out_features=out_features, bias_attr=False + ) + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + paddle.nn.init.normal_(self.down.weight, std=1 / rank) + paddle.nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + return up_hidden_states.to(orig_dtype) + + + +class LoRACompatibleLinear(paddle.nn.Linear): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): + if self.lora_layer is None: + return + dtype, device = self.weight.data.dtype, self.weight.data.place + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + fused_weight = ( + w_orig + lora_scale * paddle.bmm(w_up[None, :], w_down[None, :])[0] + ) + if safe_fusing and paddle.isnan(fused_weight).any().item(): + raise ValueError( + f"This LoRA weight seems to be broken. Encountered NaN values when trying to fuse LoRA weights for {self}.LoRA weights will not be fused." 
+ ) + self.weight.data = fused_weight.to(device=device, dtype=dtype) + self.lora_layer = None + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not ( + getattr(self, "w_up", None) is not None + and getattr(self, "w_down", None) is not None + ): + return + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.place + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + unfused_weight = ( + fused_weight.float() + - self._lora_scale * paddle.bmm(w_up[None, :], w_down[None, :])[0] + ) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + self.w_up = None + self.w_down = None + + def forward( + self, hidden_states: paddle.Tensor, scale: float = 1.0 + ) -> paddle.Tensor: + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + scale * self.lora_layer( + hidden_states + ) + return out diff --git a/paddlespeech/t2s/modules/flow/mask.py b/paddlespeech/t2s/modules/flow/mask.py new file mode 100644 index 000000000..697ca0eb5 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/mask.py @@ -0,0 +1,287 @@ +import paddle + +def device2str(type=None, index=None, *, device=None): + type = device if device else type + if isinstance(type, int): + type = f'gpu:{type}' + elif isinstance(type, str): + if 'cuda' in type: + type = type.replace('cuda', 'gpu') + if 'cpu' in type: + type = 'cpu' + elif index is not None: + type = f'{type}:{index}' + elif isinstance(type, paddle.CPUPlace) or (type is None): + type = 'cpu' + elif isinstance(type, paddle.CUDAPlace): + type = f'gpu:{type.get_device_id()}' + + return type + +def _Tensor_max(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.maximum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.max(self, *args, **kwargs), paddle.argmax(self, *args, **kwargs) + else: + ret = paddle.max(self, *args, **kwargs) + + return ret + +setattr(paddle.Tensor, "_max", _Tensor_max) + + + +""" +def subsequent_mask( + size: int, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + ""\"Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. + + In encoder, fully attention is used when streaming is not necessary and + the sequence is not long. In this case, no attention mask is needed. + + When streaming is need, chunk-based attention is used in encoder. See + subsequent_chunk_mask for the chunk-based attention mask. + + Args: + size (int): size of mask + str device (str): "cpu" or "cuda" or torch.Tensor.device + dtype (torch.device): result dtype + + Returns: + torch.Tensor: mask + + Examples: + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + ""\" + ret = torch.ones(size, size, device=device, dtype=torch.bool) + return torch.tril(ret) +""" + + +def subsequent_mask( +>>>>>> size: int, device: torch.device = device2str("cpu") +) -> paddle.Tensor: + """Create mask for subsequent steps (size, size). + + This mask is used only in decoder which works in an auto-regressive mode. + This means the current step could only do attention with its left steps. 
+
+    In encoder, full attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        device (str): "cpu" or "cuda" or torch.Tensor.device
+        dtype (torch.dtype): result dtype
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_mask(3)
+        [[1, 0, 0],
+         [1, 1, 0],
+         [1, 1, 1]]
+    """
+    arange = paddle.arange(size, device=device)
+    mask = arange.expand(size, size)
+    arange = arange.unsqueeze(-1)
+    mask = mask <= arange
+    return mask
+
+
+def subsequent_chunk_mask_deprecated(
+    size: int,
+    chunk_size: int,
+    num_left_chunks: int = -1,
+    device: str = device2str("cpu"),
+) -> paddle.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+    this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    ret = paddle.zeros(size, size, device=device, dtype=paddle.bool)
+    for i in range(size):
+        if num_left_chunks < 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+        ending = min((i // chunk_size + 1) * chunk_size, size)
+        ret[i, start:ending] = True
+    return ret
+
+
+def subsequent_chunk_mask(
+    size: int,
+    chunk_size: int,
+    num_left_chunks: int = -1,
+    device: str = device2str("cpu"),
+) -> paddle.Tensor:
+    """Create mask for subsequent steps (size, size) with chunk size,
+    this is for streaming encoder
+
+    Args:
+        size (int): size of mask
+        chunk_size (int): size of chunk
+        num_left_chunks (int): number of left chunks
+            <0: use full chunk
+            >=0: use num_left_chunks
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+    Returns:
+        torch.Tensor: mask
+
+    Examples:
+        >>> subsequent_chunk_mask(4, 2)
+        [[1, 1, 0, 0],
+         [1, 1, 0, 0],
+         [1, 1, 1, 1],
+         [1, 1, 1, 1]]
+    """
+    pos_idx = paddle.arange(size, device=device)
+    block_value = (
+        paddle.div(pos_idx, chunk_size, rounding_mode="trunc") + 1
+    ) * chunk_size
+    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
+    return ret
+
+
+def add_optional_chunk_mask(
+    xs: paddle.Tensor,
+    masks: paddle.Tensor,
+    use_dynamic_chunk: bool,
+    use_dynamic_left_chunk: bool,
+    decoding_chunk_size: int,
+    static_chunk_size: int,
+    num_decoding_left_chunks: int,
+    enable_full_context: bool = True,
+):
+    """Apply optional mask for encoder.
+
+    Args:
+        xs (torch.Tensor): padded input, (B, L, D), L for max length
+        mask (torch.Tensor): mask for xs, (B, 1, L)
+        use_dynamic_chunk (bool): whether to use dynamic chunk or not
+        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+            training.
+        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+            0: default for training, use random dynamic chunk.
+            <0: for decoding, use full chunk.
+            >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. + """ + if use_dynamic_chunk: + max_len = xs.size(1) + if decoding_chunk_size < 0: + chunk_size = max_len + num_left_chunks = -1 + elif decoding_chunk_size > 0: + chunk_size = decoding_chunk_size + num_left_chunks = num_decoding_left_chunks + else: + chunk_size = paddle.randint(low=1, high=max_len, shape=(1,)).item() + num_left_chunks = -1 + if chunk_size > max_len // 2 and enable_full_context: + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + if use_dynamic_left_chunk: + max_left_chunks = (max_len - 1) // chunk_size + num_left_chunks = paddle.randint( + low=0, high=max_left_chunks, shape=(1,) + ).item() + chunk_masks = subsequent_chunk_mask( + xs.size(1), chunk_size, num_left_chunks, xs.place + ) + chunk_masks = chunk_masks.unsqueeze(0) + chunk_masks = masks & chunk_masks + elif static_chunk_size > 0: + num_left_chunks = num_decoding_left_chunks + chunk_masks = subsequent_chunk_mask( + xs.size(1), static_chunk_size, num_left_chunks, xs.place + ) + chunk_masks = chunk_masks.unsqueeze(0) + chunk_masks = masks & chunk_masks + else: + chunk_masks = masks + assert chunk_masks.dtype == paddle.bool + if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: + print( + "get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!" + ) + chunk_masks[chunk_masks.sum(dim=-1) == 0] = True + return chunk_masks + + +def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.size(0) + max_len = max_len if max_len > 0 else lengths._max().item() + seq_range = paddle.arange(0, max_len, dtype=paddle.int64, device=lengths.place) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask \ No newline at end of file diff --git a/paddlespeech/t2s/modules/flow/matcha_decoder.py b/paddlespeech/t2s/modules/flow/matcha_decoder.py new file mode 100644 index 000000000..7ecb100b2 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/matcha_decoder.py @@ -0,0 +1,445 @@ +import math +from typing import Optional + +import einops +import paddle +from conformer import ConformerBlock +from matcha.models.components.transformer import BasicTransformerBlock +ACTIVATION_FUNCTIONS = { + "swish": paddle.nn.SiLU(), + "silu": paddle.nn.SiLU(), + "mish": paddle.nn.Mish(), + "gelu": paddle.nn.GELU(), + "relu": paddle.nn.ReLU(), +} +def get_activation(act_fn: str) -> paddle.nn.Layer: + """Helper function to get activation function from string. + + Args: + act_fn (str): Name of activation function. + + Returns: + nn.Module: Activation function. 
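+
+    Example (illustrative):
+        >>> act = get_activation("gelu")
+        >>> y = act(paddle.randn([2, 4]))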
+ """ + act_fn = act_fn.lower() + if act_fn in ACTIVATION_FUNCTIONS: + return ACTIVATION_FUNCTIONS[act_fn] + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + +class SinusoidalPosEmb(paddle.nn.Layer): + def __init__(self, dim): + super().__init__() + self.dim = dim + assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + device = x.place + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = paddle.exp(x=paddle.arange(half_dim, device=device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = paddle.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Block1D(paddle.nn.Layer): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = paddle.nn.Sequential( + paddle.nn.Conv1d(dim, dim_out, 3, padding=1), + paddle.nn.GroupNorm(num_groups=groups, num_channels=dim_out), + paddle.nn.Mish(), + ) + + def forward(self, x, mask): + output = self.block(x * mask) + return output * mask + + +class ResnetBlock1D(paddle.nn.Layer): + def __init__(self, dim, dim_out, time_emb_dim, groups=8): + super().__init__() + self.mlp = paddle.nn.Sequential( + paddle.nn.Mish(), + paddle.nn.Linear(in_features=time_emb_dim, out_features=dim_out), + ) + self.block1 = Block1D(dim, dim_out, groups=groups) + self.block2 = Block1D(dim_out, dim_out, groups=groups) + self.res_conv = paddle.nn.Conv1d(dim, dim_out, 1) + + def forward(self, x, mask, time_emb): + h = self.block1(x, mask) + h += self.mlp(time_emb).unsqueeze(-1) + h = self.block2(h, mask) + output = h + self.res_conv(x * mask) + return output + + +class Downsample1D(paddle.nn.Layer): + def __init__(self, dim): + super().__init__() + self.conv = paddle.nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class TimestepEmbedding(paddle.nn.Layer): + def __init__( + self, + in_channels: int, + time_embed_dim: int, + act_fn: str = "silu", + out_dim: int = None, + post_act_fn: Optional[str] = None, + cond_proj_dim=None, + ): + super().__init__() + self.linear_1 = paddle.nn.Linear( + in_features=in_channels, out_features=time_embed_dim + ) + if cond_proj_dim is not None: + self.cond_proj = paddle.nn.Linear( + in_features=cond_proj_dim, out_features=in_channels, bias_attr=False + ) + else: + self.cond_proj = None + self.act = get_activation(act_fn) + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = paddle.nn.Linear( + in_features=time_embed_dim, out_features=time_embed_dim_out + ) + if post_act_fn is None: + self.post_act = None + else: + self.post_act = get_activation(post_act_fn) + + def forward(self, sample, condition=None): + if condition is not None: + sample = sample + self.cond_proj(condition) + sample = self.linear_1(sample) + if self.act is not None: + sample = self.act(sample) + sample = self.linear_2(sample) + if self.post_act is not None: + sample = self.post_act(sample) + return sample + + +class Upsample1D(paddle.nn.Layer): + """A 1D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. 
+ """ + + def __init__( + self, + channels, + use_conv=False, + use_conv_transpose=True, + out_channels=None, + name="conv", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.conv = None + if use_conv_transpose: + self.conv = paddle.nn.Conv1DTranspose( + in_channels=channels, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ) + elif use_conv: + self.conv = paddle.nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, inputs): + assert inputs.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + outputs = paddle.nn.functional.interpolate( + x=inputs, scale_factor=2.0, mode="nearest" + ) + if self.use_conv: + outputs = self.conv(outputs) + return outputs + + +class ConformerWrapper(ConformerBlock): + def __init__( + self, + *, + dim, + dim_head=64, + heads=8, + ff_mult=4, + conv_expansion_factor=2, + conv_kernel_size=31, + attn_dropout=0, + ff_dropout=0, + conv_dropout=0, + conv_causal=False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward( + self, + hidden_states, + attention_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + ): + return super().forward(x=hidden_states, mask=attention_mask.bool()) + + +class Decoder(paddle.nn.Layer): + def __init__( + self, + in_channels, + out_channels, + channels=(256, 256), + dropout=0.05, + attention_head_dim=64, + n_blocks=1, + num_mid_blocks=2, + num_heads=4, + act_fn="snake", + down_block_type="transformer", + mid_block_type="transformer", + up_block_type="transformer", + ): + super().__init__() + channels = tuple(channels) + self.in_channels = in_channels + self.out_channels = out_channels + self.time_embeddings = SinusoidalPosEmb(in_channels) + time_embed_dim = channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=in_channels, time_embed_dim=time_embed_dim, act_fn="silu" + ) + self.down_blocks = paddle.nn.LayerList(sublayers=[]) + self.mid_blocks = paddle.nn.LayerList(sublayers=[]) + self.up_blocks = paddle.nn.LayerList(sublayers=[]) + output_channel = in_channels + for i in range(len(channels)): + input_channel = output_channel + output_channel = channels[i] + is_last = i == len(channels) - 1 + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) + transformer_blocks = paddle.nn.LayerList( + sublayers=[ + self.get_block( + down_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + downsample = ( + Downsample1D(output_channel) + if not is_last + else paddle.nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.down_blocks.append( + paddle.nn.LayerList(sublayers=[resnet, transformer_blocks, downsample]) + ) + for i in range(num_mid_blocks): + input_channel = channels[-1] + out_channels = channels[-1] + resnet = ResnetBlock1D( + dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim + ) + transformer_blocks = paddle.nn.LayerList( + sublayers=[ + self.get_block( + mid_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in 
range(n_blocks) + ] + ) + self.mid_blocks.append( + paddle.nn.LayerList(sublayers=[resnet, transformer_blocks]) + ) + channels = channels[::-1] + (channels[0],) + for i in range(len(channels) - 1): + input_channel = channels[i] + output_channel = channels[i + 1] + is_last = i == len(channels) - 2 + resnet = ResnetBlock1D( + dim=2 * input_channel, + dim_out=output_channel, + time_emb_dim=time_embed_dim, + ) + transformer_blocks = paddle.nn.LayerList( + sublayers=[ + self.get_block( + up_block_type, + output_channel, + attention_head_dim, + num_heads, + dropout, + act_fn, + ) + for _ in range(n_blocks) + ] + ) + upsample = ( + Upsample1D(output_channel, use_conv_transpose=True) + if not is_last + else paddle.nn.Conv1d(output_channel, output_channel, 3, padding=1) + ) + self.up_blocks.append( + paddle.nn.LayerList(sublayers=[resnet, transformer_blocks, upsample]) + ) + self.final_block = Block1D(channels[-1], channels[-1]) + self.final_proj = paddle.nn.Conv1d(channels[-1], self.out_channels, 1) + self.initialize_weights() + + @staticmethod + def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): + if block_type == "conformer": + block = ConformerWrapper( + dim=dim, + dim_head=attention_head_dim, + heads=num_heads, + ff_mult=1, + conv_expansion_factor=2, + ff_dropout=dropout, + attn_dropout=dropout, + conv_dropout=dropout, + conv_kernel_size=31, + ) + elif block_type == "transformer": + block = BasicTransformerBlock( + dim=dim, + num_attention_heads=num_heads, + attention_head_dim=attention_head_dim, + dropout=dropout, + activation_fn=act_fn, + ) + else: + raise ValueError(f"Unknown block type {block_type}") + return block + + def initialize_weights(self): + for m in self.sublayers(): + if isinstance(m, paddle.nn.Conv1d): + paddle.nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + paddle.nn.init.constant_(m.bias, 0) + elif isinstance(m, paddle.nn.GroupNorm): + paddle.nn.init.constant_(m.weight, 1) + paddle.nn.init.constant_(m.bias, 0) + elif isinstance(m, paddle.nn.Linear): + paddle.nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + paddle.nn.init.constant_(m.bias, 0) + + def forward(self, x, mask, mu, t, spks=None, cond=None): + """Forward pass of the UNet1DConditional model. + + Args: + x (torch.Tensor): shape (batch_size, in_channels, time) + mask (_type_): shape (batch_size, 1, time) + t (_type_): shape (batch_size) + spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. + cond (_type_, optional): placeholder for future use. Defaults to None. 
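+
+        Note:
+            `mu` (and `spks`, when given) is packed with `x` along the channel
+            axis before the U-Net blocks, so the first ResNet block operates on
+            the packed channel width.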
+ + Raises: + ValueError: _description_ + ValueError: _description_ + + Returns: + _type_: _description_ + """ + t = self.time_embeddings(t) + t = self.time_mlp(t) + x = einops.pack([x, mu], "b * t")[0] + if spks is not None: + spks = einops.repeat(spks, "b c -> b c t", t=x.shape[-1]) + x = einops.pack([x, spks], "b * t")[0] + hiddens = [] + masks = [mask] + for resnet, transformer_blocks, downsample in self.down_blocks: + mask_down = masks[-1] + x = resnet(x, mask_down, t) + x = einops.rearrange(x, "b c t -> b t c") + mask_down = einops.rearrange(mask_down, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, attention_mask=mask_down, timestep=t + ) + x = einops.rearrange(x, "b t c -> b c t") + mask_down = einops.rearrange(mask_down, "b t -> b 1 t") + hiddens.append(x) + x = downsample(x * mask_down) + masks.append(mask_down[:, :, ::2]) + masks = masks[:-1] + mask_mid = masks[-1] + for resnet, transformer_blocks in self.mid_blocks: + x = resnet(x, mask_mid, t) + x = einops.rearrange(x, "b c t -> b t c") + mask_mid = einops.rearrange(mask_mid, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, attention_mask=mask_mid, timestep=t + ) + x = einops.rearrange(x, "b t c -> b c t") + mask_mid = einops.rearrange(mask_mid, "b t -> b 1 t") + for resnet, transformer_blocks, upsample in self.up_blocks: + mask_up = masks.pop() + x = resnet(einops.pack([x, hiddens.pop()], "b * t")[0], mask_up, t) + x = einops.rearrange(x, "b c t -> b t c") + mask_up = einops.rearrange(mask_up, "b 1 t -> b t") + for transformer_block in transformer_blocks: + x = transformer_block( + hidden_states=x, attention_mask=mask_up, timestep=t + ) + x = einops.rearrange(x, "b t c -> b c t") + mask_up = einops.rearrange(mask_up, "b t -> b 1 t") + x = upsample(x * mask_up) + x = self.final_block(x, mask_up) + output = self.final_proj(x * mask_up) + return output * mask diff --git a/paddlespeech/t2s/modules/flow/matcha_transformer.py b/paddlespeech/t2s/modules/flow/matcha_transformer.py new file mode 100644 index 000000000..db2214729 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/matcha_transformer.py @@ -0,0 +1,313 @@ +from typing import Any, Dict, Optional +from .attention_processor import Attention +import paddle +from paddle import nn +from paddlespeech.t2s.modules.flow.lora import LoRACompatibleLinear +class SnakeBeta(paddle.nn.Layer): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + """ + + def __init__( + self, + in_features, + out_features, + alpha=1.0, + alpha_trainable=True, + alpha_logscale=True, + ): + """ + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. 
+ """ + super().__init__() + self.in_features = ( + out_features if isinstance(out_features, list) else [out_features] + ) + self.proj = LoRACompatibleLinear( + in_features, out_features + ) + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + self.alpha = paddle.nn.parameter.Parameter( + paddle.zeros(self.in_features) * alpha + ) + self.beta = paddle.nn.parameter.Parameter( + paddle.zeros(self.in_features) * alpha + ) + else: + self.alpha = paddle.nn.parameter.Parameter( + paddle.ones(self.in_features) * alpha + ) + self.beta = paddle.nn.parameter.Parameter( + paddle.ones(self.in_features) * alpha + ) + self.alpha.stop_gradient = not alpha_trainable + self.beta.stop_gradient = not alpha_trainable + self.no_div_by_zero = 1e-09 + + def forward(self, x): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + x = self.proj(x) + if self.alpha_logscale: + alpha = paddle.exp(x=self.alpha) + beta = paddle.exp(x=self.beta) + else: + alpha = self.alpha + beta = self.beta + x = x + 1.0 / (beta + self.no_div_by_zero) * paddle.pow( + paddle.sin(x * alpha), 2 + ) + return x +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +class GELU(nn.Layer): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias_attr=bias) + self.approximate = approximate + + def gelu(self, gate: paddle.Tensor) -> paddle.Tensor: + if self.approximate == "tanh": + approximate_bool = True + else: + approximate_bool = False + + if gate.dtype == paddle.float16: + return F.gelu(gate.astype(paddle.float32), approximate=approximate_bool).astype(paddle.float16) + else: + return F.gelu(gate, approximate=approximate_bool) + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + +class FeedForward(paddle.nn.Layer): + """ + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + act_fn = GELU(dim, inner_dim) + self.net = paddle.nn.LayerList(sublayers=[]) + self.net.append(act_fn) + self.net.append(paddle.nn.Dropout(p=dropout)) + self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) + if final_dropout: + self.net.append(paddle.nn.Dropout(p=dropout)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + + +class BasicTransformerBlock(paddle.nn.Layer): + """ + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. 
+ attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm_zero = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm_zero" + ) + self.use_ada_layer_norm = ( + num_embeds_ada_norm is not None and norm_type == "ada_norm" + ) + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 
+ ) + + self.norm1 = paddle.nn.LayerNorm( + normalized_shape=dim, + weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine, + ) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm2 = None + self.attn2 = None + self.norm3 = paddle.nn.LayerNorm( + normalized_shape=dim, + weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + ) + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + timestep: Optional[paddle.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[paddle.Tensor] = None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + (norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=encoder_attention_mask + if self.only_cross_attention + else attention_mask, + **cross_attention_kwargs, + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + if self._chunk_size is not None: + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 
+ ) + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ff_output = paddle.cat( + [ + self.ff(hid_slice) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + hidden_states = ff_output + hidden_states + return hidden_states diff --git a/paddlespeech/t2s/modules/flow/normalization.py b/paddlespeech/t2s/modules/flow/normalization.py new file mode 100644 index 000000000..ba0ccbc55 --- /dev/null +++ b/paddlespeech/t2s/modules/flow/normalization.py @@ -0,0 +1,276 @@ +import numbers +from typing import Dict, Optional, Tuple + +import paddle +from paddle import nn +from .activations import get_activation +from .embeddings import (CombinedTimestepLabelEmbeddings, + PixArtAlphaCombinedTimestepSizeEmbeddings) +def get_activation(act_fn): + if act_fn == "silu": + return nn.Silu() + elif act_fn == "mish": + return nn.Mish() + elif act_fn == "relu": + return nn.ReLU() + elif act_fn == "gelu": + return nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") + +class AdaLayerNorm(paddle.nn.Layer): + """ + Norm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, num_embeddings: int): + super().__init__() + self.emb = paddle.nn.Embedding(num_embeddings, embedding_dim) + self.silu = paddle.nn.SiLU() + self.linear = paddle.nn.Linear( + in_features=embedding_dim, out_features=embedding_dim * 2 + ) + self.norm = paddle.nn.LayerNorm( + normalized_shape=embedding_dim, weight_attr=False, bias_attr=False + ) + + def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor: + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = paddle.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class AdaLayerNormZero(paddle.nn.Layer): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. 
+ """ + + def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None): + super().__init__() + if num_embeddings is not None: + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + else: + self.emb = None + self.silu = paddle.nn.SiLU() + self.linear = paddle.nn.Linear( + in_features=embedding_dim, out_features=6 * embedding_dim, bias_attr=True + ) + self.norm = paddle.nn.LayerNorm( + normalized_shape=embedding_dim, + weight_attr=False, + bias_attr=False, + epsilon=1e-06, + ) + + def forward( + self, + x: paddle.Tensor, + timestep: Optional[paddle.Tensor] = None, + class_labels: Optional[paddle.LongTensor] = None, + hidden_dtype: Optional[paddle.dtype] = None, + emb: Optional[paddle.Tensor] = None, + ) -> Tuple[ + paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor + ]: + if self.emb is not None: + emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk( + 6, dim=1 + ) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormSingle(paddle.nn.Layer): + """ + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + ) + self.silu = paddle.nn.SiLU() + self.linear = paddle.nn.Linear( + in_features=embedding_dim, out_features=6 * embedding_dim, bias_attr=True + ) + + def forward( + self, + timestep: paddle.Tensor, + added_cond_kwargs: Optional[Dict[str, paddle.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[paddle.dtype] = None, + ) -> Tuple[ + paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor + ]: + embedded_timestep = self.emb( + timestep, + **added_cond_kwargs, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class AdaGroupNorm(paddle.nn.Layer): + """ + GroupNorm layer modified to incorporate timestep embeddings. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + num_groups (`int`): The number of groups to separate the channels into. + act_fn (`str`, *optional*, defaults to `None`): The activation function to use. + eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability. 
+ """ + + def __init__( + self, + embedding_dim: int, + out_dim: int, + num_groups: int, + act_fn: Optional[str] = None, + eps: float = 1e-05, + ): + super().__init__() + self.num_groups = num_groups + self.eps = eps + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) + self.linear = paddle.nn.Linear( + in_features=embedding_dim, out_features=out_dim * 2 + ) + + def forward(self, x: paddle.Tensor, emb: paddle.Tensor) -> paddle.Tensor: + if self.act: + emb = self.act(emb) + emb = self.linear(emb) + emb = emb[:, :, None, None] + scale, shift = emb.chunk(2, dim=1) + x = paddle.nn.functional.group_norm( + x=x, num_groups=self.num_groups, epsilon=self.eps + ) + x = x * (1 + scale) + shift + return x + + +class AdaLayerNormContinuous(paddle.nn.Layer): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=True, + eps=1e-05, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = paddle.nn.SiLU() + self.linear = paddle.nn.Linear( + in_features=conditioning_embedding_dim, + out_features=embedding_dim * 2, + bias_attr=bias, + ) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, x: paddle.Tensor, conditioning_embedding: paddle.Tensor + ) -> paddle.Tensor: + emb = self.linear(self.silu(conditioning_embedding).to(x.dtype)) + scale, shift = paddle.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + +LayerNorm = paddle.nn.LayerNorm + + +class LayerNorm(paddle.nn.Layer): + def __init__( + self, + dim, + eps: float = 1e-05, + elementwise_affine: bool = True, + bias: bool = True, + ): + super().__init__() + self.eps = eps + if isinstance(dim, numbers.Integral): + dim = (dim,) + self.dim = paddle.Size(dim) + if elementwise_affine: + self.weight = paddle.nn.parameter.Parameter(paddle.ones(dim)) + self.bias = ( + paddle.nn.parameter.Parameter(paddle.zeros(dim)) if bias else None + ) + else: + self.weight = None + self.bias = None + + def forward(self, input): + return paddle.nn.functional.layer_norm( + input, self.dim, self.weight, self.bias, self.eps + ) + + +class RMSNorm(paddle.nn.Layer): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + self.eps = eps + if isinstance(dim, numbers.Integral): + dim = (dim,) + self.dim = paddle.Size(dim) + if elementwise_affine: + self.weight = paddle.nn.parameter.Parameter(paddle.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(paddle.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * paddle.rsqrt(variance + self.eps) + if self.weight is not None: + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + return hidden_states + + +class GlobalResponseNorm(paddle.nn.Layer): + def __init__(self, dim): + super().__init__() + self.gamma = paddle.nn.parameter.Parameter(paddle.zeros(1, 1, 1, dim)) + self.beta = paddle.nn.parameter.Parameter(paddle.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = paddle.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 
1e-06) + return self.gamma * (x * nx) + self.beta + x diff --git a/paddlespeech/t2s/modules/predictor/length_regulator.py b/paddlespeech/t2s/modules/predictor/length_regulator.py index bdfa18391..bb4b5a64d 100644 --- a/paddlespeech/t2s/modules/predictor/length_regulator.py +++ b/paddlespeech/t2s/modules/predictor/length_regulator.py @@ -108,7 +108,6 @@ class LengthRegulator(nn.Layer): Returns: Tensor: replicated input tensor based on durations (B, T*, D). """ - if alpha != 1.0: assert alpha > 0 ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha) diff --git a/paddlespeech/t2s/modules/tokenizer.py b/paddlespeech/t2s/modules/tokenizer.py new file mode 100644 index 000000000..53ea044f5 --- /dev/null +++ b/paddlespeech/t2s/modules/tokenizer.py @@ -0,0 +1,241 @@ +import base64 +import os +from functools import lru_cache +from paddlenlp.transformers import AutoTokenizer +import paddle +import tiktoken + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", + "minnan": "minnan", + "wuyu": "wuyu", + "dialect": "dialect", + "zh/en": "zh/en", + "en/zh": "en/zh", +} +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", + "mandarin": "zh", +} +AUDIO_EVENT = { + "ASR": "ASR", + "AED": "AED", + "SER": "SER", + "Speech": "Speech", + "/Speech": "/Speech", + "BGM": "BGM", + "/BGM": "/BGM", + "Laughter": "Laughter", + "/Laughter": "/Laughter", + "Applause": "Applause", + "/Applause": "/Applause", +} +EMOTION = 
{"HAPPY": "HAPPY", "SAD": "SAD", "ANGRY": "ANGRY", "NEUTRAL": "NEUTRAL"} +TTS_Vocal_Token = { + "TTS/B": "TTS/B", + "TTS/O": "TTS/O", + "TTS/Q": "TTS/Q", + "TTS/A": "TTS/A", + "TTS/CO": "TTS/CO", + "TTS/CL": "TTS/CL", + "TTS/H": "TTS/H", + **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}, +} + + +@lru_cache(maxsize=None) +def get_encoding(name: str = "gpt2", num_languages: int = 99): + vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") + ranks = { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in open(vocab_path) if line) + } + n_vocab = len(ranks) + special_tokens = {} + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], + *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], + *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], + *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + for token in specials: + special_tokens[token] = n_vocab + n_vocab += 1 + return tiktoken.Encoding( + name=os.path.basename(vocab_path), + explicit_n_vocab=n_vocab, + pat_str="'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + + +class QwenTokenizer: + def __init__(self, skip_special_tokens=True): + super().__init__() + special_tokens = { + "eos_token": "<|endoftext|>", + "pad_token": "<|endoftext|>", + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|endofprompt|>", + "[breath]", + "", + "", + "[noise]", + "[laughter]", + "[cough]", + "[clucking]", + "[accent]", + "[quick_breath]", + "", + "", + "[hissing]", + "[sigh]", + "[vocalized-noise]", + "[lipsmack]", + "[mn]", + ], + } + self.special_tokens = special_tokens + self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B") + self.tokenizer.add_special_tokens(special_tokens) + self.skip_special_tokens = skip_special_tokens + + def encode(self, text, **kwargs): + tokens = self.tokenizer([text], return_tensors="pd") + tokens = tokens["input_ids"][0].cpu().tolist() + return tokens + + def decode(self, tokens): + tokens = paddle.tensor(tokens, dtype=paddle.int64) + text = self.tokenizer.batch_decode( + [tokens], skip_special_tokens=self.skip_special_tokens + )[0] + return text + + +@lru_cache(maxsize=None) +def get_qwen_tokenizer(skip_special_tokens: bool) -> QwenTokenizer: + return QwenTokenizer(skip_special_tokens=skip_special_tokens) diff --git a/paddlespeech/t2s/modules/transformer/__init__.py b/paddlespeech/t2s/modules/transformer/__init__.py index abf198b97..a9cc79cc9 100644 --- a/paddlespeech/t2s/modules/transformer/__init__.py +++ b/paddlespeech/t2s/modules/transformer/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/paddlespeech/t2s/modules/transformer/activation.py b/paddlespeech/t2s/modules/transformer/activation.py
new file mode 100644
index 000000000..c7363a406
--- /dev/null
+++ b/paddlespeech/t2s/modules/transformer/activation.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import paddle
+
+"""Swish() activation function for Conformer."""
+
+
+class Swish(paddle.nn.Layer):
+    """Construct a Swish object."""
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        """Return Swish activation function."""
+        return x * paddle.nn.functional.sigmoid(x)
+
+
+class Snake(paddle.nn.Layer):
+    '''
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = paddle.randn([1, 256, 100])  # Example input
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: number of input features (channel dimension)
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+        self.alpha_logscale = alpha_logscale
+
+        # Small constant that avoids division by zero
+        self.no_div_by_zero = 1e-9
+
+        # Initial value of alpha: 0 in log scale, alpha in linear scale
+        if self.alpha_logscale:
+            initial_value = 0.0  # log scale: start at 0, exp() is applied in the forward pass
+        else:
+            initial_value = alpha  # linear scale: initialize directly to alpha
+
+        # Create the trainable parameter alpha the PaddlePaddle way
+        # Note: self.create_parameter is used here rather than paddle.create_parameter
+        self.alpha = self.create_parameter(
+            shape=[in_features],  # parameter shape is [in_features]
+            dtype='float32',  # data type
+            default_initializer=paddle.nn.initializer.Constant(value=initial_value)  # initializer
+        )
+
+        # Control whether the parameter receives gradient updates (i.e. is trainable);
+        # in PaddlePaddle this is done by setting stop_gradient
+        self.alpha.stop_gradient = not alpha_trainable
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2 (xa)
+        '''
+        # Reshape alpha to match the input x: [B, C, T] -> alpha becomes [1, C, 1]
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # from [C] to [1, C, 1]
+
+        # In log scale, exponentiate alpha
+        if self.alpha_logscale:
+            alpha = paddle.exp(alpha)
+
+        # Compute the Snake activation
+        # Formula: x + (1.0 / (alpha + epsilon)) * sin(x * alpha)^2
+        sin_term = paddle.sin(x * alpha)
+        result = x + (1.0 / (alpha + self.no_div_by_zero)) * (sin_term ** 2)
+
+        return result
diff --git a/paddlespeech/t2s/modules/transformer/attention.py b/paddlespeech/t2s/modules/transformer/attention.py
index 3237be1b6..5d883ab44 100644
--- a/paddlespeech/t2s/modules/transformer/attention.py
+++ b/paddlespeech/t2s/modules/transformer/attention.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,11 +14,9 @@
 # Modified from espnet(https://github.com/espnet/espnet)
 """Multi-Head Attention layer definition."""
 import math
-
 import numpy
 import paddle
 from paddle import nn
-
 from paddlespeech.t2s.modules.masked_fill import masked_fill
@@ -199,7 +197,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
             x = x * paddle.tril(ones, t2 - t1)[None, None, :, :]
         return x
 
-    def forward(self, query, key, value, pos_emb, mask):
+    def forward(self, query, key, value, pos_emb, mask, cache):
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
 
         Args:
@@ -220,6 +218,11 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         q, k, v = self.forward_qkv(query, key, value)
         # (batch, time1, head, d_k)
         q = q.transpose([0, 2, 1, 3])
+        if cache is not None and cache.shape[0] > 0:
+            key_cache, value_cache = paddle.split(cache, num_or_sections=2, axis=-1)
+            k = paddle.concat([key_cache, k], axis=2)
+            v = paddle.concat([value_cache, v], axis=2)
+        new_cache = paddle.concat([k, v], axis=-1)
         n_batch_pos = paddle.shape(pos_emb)[0]
         p = self.linear_pos(pos_emb).reshape(
             [n_batch_pos, -1, self.h, self.d_k])
@@ -243,7 +246,7 @@
         # (batch, head, time1, time2)
         scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
 
-        return self.forward_attention(v, scores, mask)
+        return self.forward_attention(v, scores, mask), new_cache
 
 
 class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
diff --git a/paddlespeech/t2s/modules/transformer/convolution.py b/paddlespeech/t2s/modules/transformer/convolution.py
new file mode 100644
index 000000000..05027c6cf
--- /dev/null
+++ b/paddlespeech/t2s/modules/transformer/convolution.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
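A quick sanity check for the Snake and Swish activations defined in activation.py above (an illustrative sketch; only the import path is taken from the diff, everything else is assumed):

# Shape/behavior check for the activations above; purely illustrative.
import paddle

from paddlespeech.t2s.modules.transformer.activation import Snake, Swish

x = paddle.randn([2, 8, 100])                 # (B, C, T)
snake = Snake(in_features=8, alpha=1.0, alpha_logscale=False)
y = snake(x)                                  # x + sin(alpha * x)^2 / (alpha + eps), elementwise
assert y.shape == x.shape

swish = Swish()
assert swish(x).shape == x.shape              # x * sigmoid(x), also elementwise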
+ +import paddle + +"""ConvolutionModule definition.""" +from typing import Tuple + + +class ConvolutionModule(paddle.nn.Layer): + """ConvolutionModule in Conformer model.""" + + def __init__( + self, + channels: int, + kernel_size: int = 15, + activation: paddle.nn.Layer = paddle.nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + ): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + causal (int): Whether use causal convolution or not + """ + super().__init__() + self.pointwise_conv1 = paddle.nn.Conv1d( + channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias + ) + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = paddle.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + assert norm in ["batch_norm", "layer_norm"] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = paddle.nn.BatchNorm1D(num_features=channels) + else: + self.use_layer_norm = True + self.norm = paddle.nn.LayerNorm(normalized_shape=channels) + self.pointwise_conv2 = paddle.nn.Conv1d( + channels, channels, kernel_size=1, stride=1, padding=0, bias=bias + ) + self.activation = activation + + def forward( + self, + x: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + cache: paddle.Tensor = paddle.zeros((0, 0, 0)), + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
+ """ + x = x.transpose(1, 2) + if mask_pad.size(2) > 0: + x.masked_fill_(~mask_pad, 0.0) + if self.lorder > 0: + if cache.size(2) == 0: + x = paddle.compat.pad(x, (self.lorder, 0), "constant", 0.0) + else: + assert cache.size(0) == x.size(0) + assert cache.size(1) == x.size(1) + x = paddle.cat((cache, x), dim=2) + assert x.size(2) > self.lorder + new_cache = x[:, :, -self.lorder :] + else: + new_cache = paddle.zeros((0, 0, 0), dtype=x.dtype, device=x.place) + x = self.pointwise_conv1(x) + x = paddle.nn.functional.glu(x=x, axis=1) + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + if mask_pad.size(2) > 0: + x.masked_fill_(~mask_pad, 0.0) + return x.transpose(1, 2), new_cache diff --git a/paddlespeech/t2s/modules/transformer/embedding.py b/paddlespeech/t2s/modules/transformer/embedding.py index e4331cff0..24a076f31 100644 --- a/paddlespeech/t2s/modules/transformer/embedding.py +++ b/paddlespeech/t2s/modules/transformer/embedding.py @@ -14,7 +14,7 @@ # Modified from espnet(https://github.com/espnet/espnet) """Positional Encoding Module.""" import math - +from typing import Union import paddle from paddle import nn @@ -131,6 +131,108 @@ class ScaledPositionalEncoding(PositionalEncoding): x = x + self.alpha * self.pe[:, :T] return self.dropout(x) +class EspnetRelPositionalEncoding(paddle.nn.Layer): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Construct an PositionalEncoding object.""" + super(EspnetRelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(paddle.to_tensor([0.0]).expand([1, max_len])) + + def extend_pe(self, x: paddle.Tensor): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.shape[1] >= x.shape[1] * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.place != x.place: + self.pe = self.pe.to(dtype=x.dtype, device=x.place) + return + pe_positive = paddle.zeros([x.shape[1], self.d_model]) + pe_negative = paddle.zeros([x.shape[1], self.d_model]) + position = paddle.arange(0, x.shape[1], dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + x=paddle.arange(0, self.d_model, 2, dtype=paddle.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe_positive[:, 0::2] = paddle.sin(position * div_term) + pe_positive[:, 1::2] = paddle.cos(position * div_term) + pe_negative[:, 0::2] = paddle.sin(-1 * position * div_term) + pe_negative[:, 1::2] = paddle.cos(-1 * position * div_term) + pe_positive = paddle.flip(x=pe_positive, axis=[0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + pe = paddle.cat([pe_positive, pe_negative], dim=1) + self.pe = pe.to(device=x.place, dtype=x.dtype) + + def forward( + self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0 + ) -> tuple[paddle.Tensor, paddle.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). 
+ + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.position_encoding(size=x.shape[1], offset=offset) + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding( + self, offset: Union[int, paddle.Tensor], size: int + ) -> paddle.Tensor: + """For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. + + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + if isinstance(offset, int): + pos_emb = self.pe[ + :, + self.pe.shape[1] // 2 + - size + - offset + + 1 : self.pe.shape[1] // 2 + + size + + offset, + ] + elif isinstance(offset, paddle.Tensor): + pos_emb = self.pe[ + :, + self.pe.shape[1] // 2 + - size + - offset + + 1 : self.pe.shape[1] // 2 + + size + + offset, + ] + return pos_emb class RelPositionalEncoding(nn.Layer): """Relative positional encoding module (new implementation). diff --git a/paddlespeech/t2s/modules/transformer/encoder_layer.py b/paddlespeech/t2s/modules/transformer/encoder_layer.py index 63494b0de..b989ffd50 100644 --- a/paddlespeech/t2s/modules/transformer/encoder_layer.py +++ b/paddlespeech/t2s/modules/transformer/encoder_layer.py @@ -15,7 +15,7 @@ """Encoder self-attention layer definition.""" import paddle from paddle import nn - +from typing import Optional class EncoderLayer(nn.Layer): """Encoder layer module. @@ -111,3 +111,118 @@ class EncoderLayer(nn.Layer): x = paddle.concat([cache, x], axis=1) return x, mask + +class ConformerEncoderLayer(paddle.nn.Layer): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module + instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: paddle.nn.Layer, + feed_forward: Optional[paddle.nn.Layer] = None, + feed_forward_macaron: Optional[paddle.nn.Layer] = None, + conv_module: Optional[paddle.nn.Layer] = None, + dropout_rate: float = 0.1, + normalize_before: bool = True, + ): + """Construct an EncoderLayer object.""" + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.norm_mha = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + if feed_forward_macaron is not None: + self.norm_ff_macaron = paddle.nn.LayerNorm( + normalized_shape=size, epsilon=1e-12 + ) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.norm_final = paddle.nn.LayerNorm(normalized_shape=size, epsilon=1e-12) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + x: paddle.Tensor, + mask: paddle.Tensor, + pos_emb: paddle.Tensor, + mask_pad: paddle.Tensor = paddle.ones((0, 0, 0), dtype=paddle.bool), + att_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + cnn_cache: paddle.Tensor = paddle.zeros((0, 0, 0, 0)), + ) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute encoded features. + + Args: + x (torch.Tensor): (#batch, time, size) + mask (torch.Tensor): Mask tensor for the input (#batch, time,time), + (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): positional encoding, must not be None + for ConformerEncoderLayer. + mask_pad (torch.Tensor): batch padding mask used for conv module. + (#batch, 1,time), (0, 0, 0) means fake mask. + att_cache (torch.Tensor): Cache tensor of the KEY & VALUE + (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. + cnn_cache (torch.Tensor): Convolution cache in conformer layer + (#batch=1, size, cache_t2) + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time, time). + torch.Tensor: att_cache tensor, + (#batch=1, head, cache_t1 + time, d_k * 2). + torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). 
+ """ + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + residual = x + if self.normalize_before: + x = self.norm_mha(x) + x_att, new_att_cache = self.self_attn(x, x, x, pos_emb, mask,att_cache) + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + new_cnn_cache = paddle.zeros([0, 0, 0], dtype=x.dtype) + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.norm_conv(x) + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + if self.conv_module is not None: + x = self.norm_final(x) + return x, mask, new_att_cache, new_cnn_cache \ No newline at end of file diff --git a/paddlespeech/t2s/modules/transformer/espnet.py b/paddlespeech/t2s/modules/transformer/espnet.py new file mode 100644 index 000000000..229b1bd60 --- /dev/null +++ b/paddlespeech/t2s/modules/transformer/espnet.py @@ -0,0 +1,275 @@ +import paddle + +"""Positonal Encoding Module.""" +import math +from typing import Tuple, Union + +import numpy as np + + +class PositionalEncoding(paddle.nn.Layer): + """Positional encoding. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) + PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) + """ + + def __init__( + self, + d_model: int, + dropout_rate: float, + max_len: int = 5000, + reverse: bool = False, + ): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.max_len = max_len + self.pe = paddle.zeros(self.max_len, self.d_model) + position = paddle.arange(0, self.max_len, dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + x=paddle.arange(0, self.d_model, 2, dtype=paddle.float32) + * -(math.log(10000.0) / self.d_model) + ) + self.pe[:, 0::2] = paddle.sin(position * div_term) + self.pe[:, 1::2] = paddle.cos(position * div_term) + self.pe = self.pe.unsqueeze(0) + + def forward( + self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + offset (int, torch.tensor): position offset + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + torch.Tensor: for compatibility to RelPositionalEncoding + """ + self.pe = self.pe.to(x.place) + pos_emb = self.position_encoding(offset, x.size(1), False) + x = x * self.xscale + pos_emb + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding( + self, offset: Union[int, paddle.Tensor], size: int, apply_dropout: bool = True + ) -> paddle.Tensor: + """For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. 
+ + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + if isinstance(offset, int): + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset : offset + size] + elif isinstance(offset, paddle.Tensor) and offset.dim() == 0: + assert offset + size <= self.max_len + pos_emb = self.pe[:, offset : offset + size] + else: + assert paddle.compat.max(offset) + size <= self.max_len + index = offset.unsqueeze(1) + paddle.arange(0, size).to(offset.place) + flag = index > 0 + index = index * flag + pos_emb = paddle.nn.functional.embedding(index, self.pe[0]) + if apply_dropout: + pos_emb = self.dropout(pos_emb) + return pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward( + self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.pe = self.pe.to(x.place) + x = x * self.xscale + pos_emb = self.position_encoding(offset, x.size(1), False) + return self.dropout(x), self.dropout(pos_emb) + + +class WhisperPositionalEncoding(PositionalEncoding): + """Sinusoids position encoding used in openai-whisper.encoder""" + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): + super().__init__(d_model, dropout_rate, max_len) + self.xscale = 1.0 + log_timescale_increment = np.log(10000) / (d_model // 2 - 1) + inv_timescales = paddle.exp( + x=-log_timescale_increment * paddle.arange(d_model // 2) + ) + scaled_time = ( + paddle.arange(max_len)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + pe = paddle.cat([paddle.sin(scaled_time), paddle.cos(scaled_time)], dim=1) + delattr(self, "pe") + self.register_buffer(name="pe", tensor=pe.unsqueeze(0)) + + +class LearnablePositionalEncoding(PositionalEncoding): + """Learnable position encoding used in openai-whisper.decoder""" + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): + super().__init__(d_model, dropout_rate, max_len) + self.pe = paddle.nn.parameter.Parameter(paddle.empty(1, max_len, d_model)) + self.xscale = 1.0 + + +class NoPositionalEncoding(paddle.nn.Layer): + """No position encoding""" + + def __init__(self, d_model: int, dropout_rate: float): + super().__init__() + self.d_model = d_model + self.dropout = paddle.nn.Dropout(p=dropout_rate) + + def forward( + self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Just return zero vector for interface compatibility""" + pos_emb = paddle.zeros(1, x.size(1), self.d_model).to(x.place) + return self.dropout(x), pos_emb + + def position_encoding( + self, offset: Union[int, paddle.Tensor], size: int + ) -> paddle.Tensor: + return paddle.zeros(1, size, self.d_model) + + +class EspnetRelPositionalEncoding(paddle.nn.Layer): + """Relative positional encoding module (new implementation). 
+ + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): + """Construct an PositionalEncoding object.""" + super(EspnetRelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = paddle.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(paddle.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: paddle.Tensor): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.place != x.place: + self.pe = self.pe.to(dtype=x.dtype, device=x.place) + return + pe_positive = paddle.zeros(x.size(1), self.d_model) + pe_negative = paddle.zeros(x.size(1), self.d_model) + position = paddle.arange(0, x.size(1), dtype=paddle.float32).unsqueeze(1) + div_term = paddle.exp( + x=paddle.arange(0, self.d_model, 2, dtype=paddle.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe_positive[:, 0::2] = paddle.sin(position * div_term) + pe_positive[:, 1::2] = paddle.cos(position * div_term) + pe_negative[:, 0::2] = paddle.sin(-1 * position * div_term) + pe_negative[:, 1::2] = paddle.cos(-1 * position * div_term) + pe_positive = paddle.flip(x=pe_positive, axis=[0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + pe = paddle.cat([pe_positive, pe_negative], dim=1) + self.pe = pe.to(device=x.place, dtype=x.dtype) + + def forward( + self, x: paddle.Tensor, offset: Union[int, paddle.Tensor] = 0 + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.position_encoding(size=x.size(1), offset=offset) + return self.dropout(x), self.dropout(pos_emb) + + def position_encoding( + self, offset: Union[int, paddle.Tensor], size: int + ) -> paddle.Tensor: + """For getting encoding in a streaming fashion + + Attention!!!!! + we apply dropout only once at the whole utterance level in a none + streaming way, but will call this function several times with + increasing input size in a streaming scenario, so the dropout will + be applied several times. 
+ + Args: + offset (int or torch.tensor): start offset + size (int): required size of position encoding + + Returns: + torch.Tensor: Corresponding encoding + """ + if isinstance(offset, int): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - size + - offset + + 1 : self.pe.size(1) // 2 + + size + + offset, + ] + elif isinstance(offset, paddle.Tensor): + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - size + - offset + + 1 : self.pe.size(1) // 2 + + size + + offset, + ] + return pos_emb diff --git a/paddlespeech/t2s/modules/transformer/mask.py b/paddlespeech/t2s/modules/transformer/mask.py index 71dd37975..b0207b9fe 100644 --- a/paddlespeech/t2s/modules/transformer/mask.py +++ b/paddlespeech/t2s/modules/transformer/mask.py @@ -50,3 +50,106 @@ def target_mask(ys_in_pad, ignore_id, dtype=paddle.bool): ys_mask = ys_in_pad != ignore_id m = subsequent_mask(ys_mask.shape[-1]).unsqueeze(0) return ys_mask.unsqueeze(-2) & m + +def make_pad_mask(lengths: paddle.Tensor, max_len: int = 0) -> paddle.Tensor: + """Make mask tensor containing indices of padded part. + + See description of make_non_pad_mask. + + Args: + lengths (torch.Tensor): Batch of lengths (B,). + Returns: + torch.Tensor: Mask tensor containing indices of padded part. + + Examples: + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + """ + batch_size = lengths.shape[0] + max_len = max_len if max_len > 0 else lengths._max().item() + seq_range = paddle.arange(0, max_len, dtype=paddle.int32) + seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len]) + seq_length_expand = lengths.unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + return mask + +def add_optional_chunk_mask( + xs: paddle.Tensor, + masks: paddle.Tensor, + use_dynamic_chunk: bool, + use_dynamic_left_chunk: bool, + decoding_chunk_size: int, + static_chunk_size: int, + num_decoding_left_chunks: int, + enable_full_context: bool = True, +): + """Apply optional mask for encoder. + + Args: + xs (torch.Tensor): padded input, (B, L, D), L for max length + mask (torch.Tensor): mask for xs, (B, 1, L) + use_dynamic_chunk (bool): whether to use dynamic chunk or not + use_dynamic_left_chunk (bool): whether to use dynamic left chunk for + training. + decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + static_chunk_size (int): chunk size for static chunk training/decoding + if it's greater than 0, if use_dynamic_chunk is true, + this parameter will be ignored + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + enable_full_context (bool): + True: chunk size is either [1, 25] or full context(max_len) + False: chunk size ~ U[1, 25] + + Returns: + torch.Tensor: chunk mask of the input xs. 
+    """
+    if use_dynamic_chunk:
+        max_len = xs.shape[1]
+        if decoding_chunk_size < 0:
+            chunk_size = max_len
+            num_left_chunks = -1
+        elif decoding_chunk_size > 0:
+            chunk_size = decoding_chunk_size
+            num_left_chunks = num_decoding_left_chunks
+        else:
+            chunk_size = paddle.randint(low=1, high=max_len, shape=(1,)).item()
+            num_left_chunks = -1
+            if chunk_size > max_len // 2 and enable_full_context:
+                chunk_size = max_len
+            else:
+                chunk_size = chunk_size % 25 + 1
+                if use_dynamic_left_chunk:
+                    max_left_chunks = (max_len - 1) // chunk_size
+                    num_left_chunks = paddle.randint(
+                        low=0, high=max_left_chunks, shape=(1,)
+                    ).item()
+        chunk_masks = subsequent_chunk_mask(
+            xs.shape[1], chunk_size, num_left_chunks, xs.place
+        )
+        chunk_masks = chunk_masks.unsqueeze(0)
+        chunk_masks = masks & chunk_masks
+    elif static_chunk_size > 0:
+        num_left_chunks = num_decoding_left_chunks
+        chunk_masks = subsequent_chunk_mask(
+            xs.shape[1], static_chunk_size, num_left_chunks, xs.place
+        )
+        chunk_masks = chunk_masks.unsqueeze(0)
+        chunk_masks = masks & chunk_masks
+    else:
+        chunk_masks = masks
+    assert chunk_masks.dtype == paddle.bool
+    if (chunk_masks.sum(axis=-1) == 0).sum().item() != 0:
+        print(
+            "get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!"
+        )
+        chunk_masks[chunk_masks.sum(axis=-1) == 0] = True
+    return chunk_masks
\ No newline at end of file
diff --git a/paddlespeech/t2s/modules/transformer/subsampling.py b/paddlespeech/t2s/modules/transformer/subsampling.py
index a17278c0b..20220e8d6 100644
--- a/paddlespeech/t2s/modules/transformer/subsampling.py
+++ b/paddlespeech/t2s/modules/transformer/subsampling.py
@@ -15,10 +15,65 @@
 """Subsampling layer definition."""
 import paddle
 from paddle import nn
-
+from typing import Tuple, Union
 from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
 
+
+class BaseSubsampling(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.right_context = 0
+        self.subsampling_rate = 1
+
+    def position_encoding(
+        self, offset: Union[int, paddle.Tensor], size: int
+    ) -> paddle.Tensor:
+        return self.pos_enc.position_encoding(offset, size)
+
+
+class LinearNoSubsampling(BaseSubsampling):
+    """Linear transform of the input without subsampling.
+
+    Args:
+        idim (int): Input dimension.
+        odim (int): Output dimension.
+        dropout_rate (float): Dropout rate.
+
+    """
+    def __init__(
+        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: paddle.nn.Layer
+    ):
+        """Construct a LinearNoSubsampling object."""
+        super().__init__()
+        self.out = paddle.nn.Sequential(
+            paddle.nn.Linear(in_features=idim, out_features=odim),
+            paddle.nn.LayerNorm(normalized_shape=odim, epsilon=1e-05),
+            paddle.nn.Dropout(p=dropout_rate),
+        )
+        self.pos_enc = pos_enc_class
+        self.right_context = 0
+        self.subsampling_rate = 1
+
+    def forward(
+        self,
+        x: paddle.Tensor,
+        x_mask: paddle.Tensor,
+        offset: Union[int, paddle.Tensor] = 0,
+    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
+        """Input x.
+
+        Args:
+            x (paddle.Tensor): Input tensor (#batch, time, idim).
+            x_mask (paddle.Tensor): Input mask (#batch, 1, time).
+
+        Returns:
+            paddle.Tensor: linear output tensor (#batch, time', odim),
+                where time' = time.
+            paddle.Tensor: positional embedding.
+            paddle.Tensor: linear output mask (#batch, 1, time'),
+                where time' = time.
+
+        """
+        x = self.out(x)
+        x, pos_emb = self.pos_enc(x, offset)
+        return x, pos_emb, x_mask
 
 class Conv2dSubsampling(nn.Layer):
     """Convolutional 2D subsampling (to 1/4 length).
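Note: a minimal usage sketch (not part of the diff) of how the new make_pad_mask helper composes into the (B, 1, T) boolean validity mask consumed by the encoders added below; the lengths and shapes are illustrative only.

import paddle
from paddlespeech.t2s.modules.transformer.mask import make_pad_mask

# Three utterances of lengths 5, 3 and 2, padded to T = 5.
lengths = paddle.to_tensor([5, 3, 2], dtype="int32")
pad_mask = make_pad_mask(lengths)       # shape [3, 5], bool, True marks padding
masks = ~pad_mask.unsqueeze(1)          # shape [3, 1, 5], True marks valid frames,
                                        # the form passed to add_optional_chunk_mask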
diff --git a/paddlespeech/t2s/modules/transformer/upsample_encoder.py b/paddlespeech/t2s/modules/transformer/upsample_encoder.py
new file mode 100644
index 000000000..ea7550b27
--- /dev/null
+++ b/paddlespeech/t2s/modules/transformer/upsample_encoder.py
@@ -0,0 +1,352 @@
+"""Encoder definition."""
+from typing import Tuple
+
+import paddle
+import paddle.nn.functional as F
+from paddlespeech.t2s.modules.transformer.convolution import ConvolutionModule
+from paddlespeech.t2s.modules.transformer.encoder_layer import ConformerEncoderLayer
+from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from paddlespeech.t2s.models.CosyVoice.class_utils import (COSYVOICE_ACTIVATION_CLASSES,
+                                                           COSYVOICE_ATTENTION_CLASSES,
+                                                           COSYVOICE_EMB_CLASSES,
+                                                           COSYVOICE_SUBSAMPLE_CLASSES)
+from paddlespeech.t2s.modules.transformer.mask import add_optional_chunk_mask, make_pad_mask
+
+
+class Upsample1D(paddle.nn.Layer):
+    """A 1D nearest-neighbour upsampling layer followed by a convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs.
+        out_channels (`int`):
+            number of output channels.
+        stride (`int`, default `2`):
+            upsampling factor along the time axis.
+    """
+
+    def __init__(self, channels: int, out_channels: int, stride: int = 2):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels
+        self.stride = stride
+        self.conv = paddle.nn.Conv1D(
+            self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0
+        )
+
+    def forward(
+        self, inputs: paddle.Tensor, input_lengths: paddle.Tensor
+    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+        inputs = inputs.unsqueeze(2)
+        outputs = paddle.nn.functional.interpolate(
+            x=inputs, scale_factor=[1, float(self.stride)], mode="nearest"
+        )
+        outputs = outputs.squeeze(2)
+        outputs = F.pad(
+            outputs, [self.stride * 2, 0], value=0.0, data_format="NCL"
+        )
+        outputs = self.conv(outputs)
+        return outputs, input_lengths * self.stride
+
+
+class PreLookaheadLayer(paddle.nn.Layer):
+    def __init__(self, channels: int, pre_lookahead_len: int = 1):
+        super().__init__()
+        self.channels = channels
+        self.pre_lookahead_len = pre_lookahead_len
+        self.conv1 = paddle.nn.Conv1D(
+            channels, channels, kernel_size=pre_lookahead_len + 1, stride=1, padding=0
+        )
+        self.conv2 = paddle.nn.Conv1D(
+            channels, channels, kernel_size=3, stride=1, padding=0
+        )
+
+    def forward(
+        self, inputs: paddle.Tensor, context: paddle.Tensor = paddle.zeros([0, 0, 0])
+    ) -> paddle.Tensor:
+        """
+        inputs: (batch_size, seq_len, channels)
+        """
+        outputs = paddle.transpose(inputs, perm=[0, 2, 1])
+        context = paddle.transpose(context, perm=[0, 2, 1])
+
+        if context.shape[2] == 0:
+            outputs = F.pad(
+                outputs,
+                [0, self.pre_lookahead_len],
+                mode="constant",
+                value=0.0,
+                data_format="NCL",
+            )
+        else:
+            assert (
+                self.training is False
+            ), "you have passed context, make sure that you are running inference mode"
+            assert context.shape[2] == self.pre_lookahead_len
+            outputs = F.pad(
+                paddle.concat([outputs, context], axis=2),
+                [0, self.pre_lookahead_len - context.shape[2]],
+                mode="constant",
+                value=0.0,
+                data_format="NCL",
+            )
+
+        outputs = paddle.nn.functional.leaky_relu(x=self.conv1(outputs))
+
+        outputs = F.pad(
+            outputs,
+            [self.conv2._kernel_size[0] - 1, 0],
+            mode="constant",
+            value=0.0,
+            data_format="NCL",
+        )
+        outputs = self.conv2(outputs)
+        outputs = paddle.transpose(outputs, perm=[0, 2, 1])
+
+        outputs = outputs + inputs
+        return outputs
+
+
+class UpsampleConformerEncoder(paddle.nn.Layer):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        pos_enc_layer_type: str = "rel_pos",
+        normalize_before: bool = True,
+        static_chunk_size: int = 0,
+        use_dynamic_chunk: bool = False,
+        global_cmvn: paddle.nn.Layer = None,
+        use_dynamic_left_chunk: bool = False,
+        positionwise_conv_kernel_size: int = 1,
+        macaron_style: bool = True,
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        cnn_module_kernel: int = 15,
+        causal: bool = False,
+        cnn_module_norm: str = "batch_norm",
+        key_bias: bool = True,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            input_size (int): input dim
+            output_size (int): dimension of attention
+            attention_heads (int): the number of heads of multi head attention
+            linear_units (int): the hidden units number of position-wise feed
+                forward
+            num_blocks (int): the number of decoder blocks
+            dropout_rate (float): dropout rate
+            attention_dropout_rate (float): dropout rate in attention
+            positional_dropout_rate (float): dropout rate after adding
+                positional encoding
+            input_layer (str): input layer type.
+                optional [linear, conv2d, conv2d6, conv2d8]
+            pos_enc_layer_type (str): Encoder positional encoding layer type.
+                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
+            normalize_before (bool):
+                True: use layer_norm before each sub-block of a layer.
+                False: use layer_norm after each sub-block of a layer.
+            static_chunk_size (int): chunk size for static chunk training and
+                decoding
+            use_dynamic_chunk (bool): whether to use dynamic chunk size for
+                training or not. You can only use a fixed chunk
+                (chunk_size > 0) or a dynamic chunk size
+                (use_dynamic_chunk = True)
+            global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer
+            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
+                dynamic chunk training
+            key_bias: whether to use bias in attention.linear_k, False for whisper models.
+            gradient_checkpointing: rerunning a forward-pass segment for each
+                checkpointed segment during backward.
+ """ + super().__init__() + self._output_size = output_size + self.global_cmvn = global_cmvn + self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + COSYVOICE_EMB_CLASSES[pos_enc_layer_type]( + output_size, positional_dropout_rate + ), + ) + self.normalize_before = normalize_before + self.after_norm = paddle.nn.LayerNorm( + normalized_shape=output_size, epsilon=1e-05 + ) + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + self.gradient_checkpointing = gradient_checkpointing + activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + False, + ) + positionwise_layer_args = (output_size, linear_units, dropout_rate, activation) + convolution_layer_args = ( + output_size, + cnn_module_kernel, + activation, + cnn_module_norm, + causal, + ) + self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) + self.encoders = paddle.nn.LayerList( + sublayers=[ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args + ), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward(*positionwise_layer_args) + if macaron_style + else None, + ConvolutionModule(*convolution_layer_args) + if use_cnn_module + else None, + dropout_rate, + normalize_before, + ) + for _ in range(num_blocks) + ] + ) + self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) + self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( + input_size, + output_size, + dropout_rate, + COSYVOICE_EMB_CLASSES[pos_enc_layer_type]( + output_size, positional_dropout_rate + ), + ) + self.up_encoders = paddle.nn.LayerList( + sublayers=[ + ConformerEncoderLayer( + output_size, + COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( + *encoder_selfattn_layer_args + ), + PositionwiseFeedForward(*positionwise_layer_args), + PositionwiseFeedForward(*positionwise_layer_args) + if macaron_style + else None, + ConvolutionModule(*convolution_layer_args) + if use_cnn_module + else None, + dropout_rate, + normalize_before, + ) + for _ in range(4) + ] + ) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: paddle.Tensor, + xs_lens: paddle.Tensor, + context: paddle.Tensor = paddle.zeros([0, 0, 0]), + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + streaming: bool = False, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Embed positions in tensor. + + Args: + xs: padded input tensor (B, T, D) + xs_lens: input length (B) + decoding_chunk_size: decoding chunk size for dynamic chunk + 0: default for training, use random dynamic chunk. + <0: for decoding, use full chunk. + >0: for decoding, use fixed chunk size as set. + num_decoding_left_chunks: number of left chunks, this is for decoding, + the chunk size is decoding_chunk_size. + >=0: use num_decoding_left_chunks + <0: use all left chunks + Returns: + encoder output tensor xs, and subsampled masks + xs: padded output tensor (B, T' ~= T/subsample_rate, D) + masks: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + NOTE(xcsong): + We pass the `__call__` method of the modules instead of `forward` to the + checkpointing API because `__call__` attaches all the hooks of the module. 
+            https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+        """
+        T = xs.shape[1]
+
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)
+        if self.global_cmvn is not None:
+            xs = self.global_cmvn(xs)
+        xs, pos_emb, masks = self.embed(xs, masks)
+
+        if context.shape[1] != 0:
+            assert (
+                self.training is False
+            ), "you have passed context, make sure that you are running inference mode"
+            context_masks = paddle.ones([1, 1, context.shape[1]]).astype(masks.dtype)
+            context, _, _ = self.embed(context, context_masks, offset=xs.shape[1])
+        mask_pad = masks
+        chunk_masks = add_optional_chunk_mask(
+            xs,
+            masks,
+            False,
+            False,
+            0,
+            self.static_chunk_size if streaming is True else 0,
+            -1,
+        )
+        xs = self.pre_lookahead_layer(xs, context=context)
+
+        xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+
+        xs = paddle.transpose(xs, perm=[0, 2, 1])
+        xs, xs_lens = self.up_layer(xs, xs_lens)
+        xs = paddle.transpose(xs, perm=[0, 2, 1])
+        T = xs.shape[1]
+        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)
+        xs, pos_emb, masks = self.up_embed(xs, masks)
+        mask_pad = masks
+        chunk_masks = add_optional_chunk_mask(
+            xs,
+            masks,
+            False,
+            False,
+            0,
+            self.static_chunk_size * self.up_layer.stride if streaming is True else 0,
+            -1,
+        )
+        xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        return xs, masks
+
+    def forward_layers(
+        self,
+        xs: paddle.Tensor,
+        chunk_masks: paddle.Tensor,
+        pos_emb: paddle.Tensor,
+        mask_pad: paddle.Tensor,
+    ) -> paddle.Tensor:
+        for layer in self.encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs
+
+    def forward_up_layers(
+        self,
+        xs: paddle.Tensor,
+        chunk_masks: paddle.Tensor,
+        pos_emb: paddle.Tensor,
+        mask_pad: paddle.Tensor,
+    ) -> paddle.Tensor:
+        for layer in self.up_encoders:
+            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+        return xs
diff --git a/q.pdparams b/q.pdparams
new file mode 100644
index 000000000..731d14119
Binary files /dev/null and b/q.pdparams differ
diff --git a/q.pt b/q.pt
new file mode 100644
index 000000000..4d315c18c
Binary files /dev/null and b/q.pt differ
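Note: a minimal sketch (not part of the diff) of how the new Upsample1D block is expected to behave on token-rate features, assuming the layer definition above: nearest-neighbour interpolation doubles the time axis, a causal Conv1D smooths it, and the returned lengths are input_lengths * stride. Shapes and batch sizes are illustrative.

import paddle
from paddlespeech.t2s.modules.transformer.upsample_encoder import Upsample1D

up = Upsample1D(channels=512, out_channels=512, stride=2)
xs = paddle.randn([2, 512, 50])                      # (B, C, T) features at token rate
xs_lens = paddle.to_tensor([50, 42], dtype="int32")  # valid lengths per utterance
ys, ys_lens = up(xs, xs_lens)
print(ys.shape)                                      # [2, 512, 100]: time axis doubled
print(ys_lens)                                       # lengths scaled by stride: [100, 84]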